// common.cuh

#pragma once

#include "ggml.h"
#include "ggml-cuda.h"

#include <cstdint>
#include <memory>

#if defined(GGML_USE_HIP)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#if defined(GGML_USE_MUSA)
#define GGML_COMMON_DECL_MUSA
#define GGML_COMMON_IMPL_MUSA
#endif
#endif
#include "ggml-common.h"

#include <array>
#include <cassert>
#include <cfloat>
#include <cstdio>
#include <string>
#include <vector>

#if defined(GGML_USE_HIP)
#include "vendors/hip.h"
#elif defined(GGML_USE_MUSA)
#include "vendors/musa.h"
#else
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

#define WARP_SIZE 32
#define CUDART_HMAX  11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define GGML_CUDA_CC_PASCAL       600
#define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define GGML_CUDA_CC_VOLTA        700
#define GGML_CUDA_CC_TURING       750
#define GGML_CUDA_CC_AMPERE       800
#define GGML_CUDA_CC_ADA_LOVELACE 890

#define GGML_CUDA_CC_OFFSET_AMD      0x1000000
#define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
#define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)

// AMD
// GCN/CDNA, wave size is 64
#define GGML_CUDA_CC_GCN4   (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA   (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA   (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2  (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum for acc register renaming
#define GGML_CUDA_CC_CDNA3  (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

// RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
#define GGML_CUDA_CC_RDNA1  (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2  (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3  (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA4  (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

#define GGML_CUDA_CC_IS_AMD(cc)   (cc >= GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_RDNA(cc)  (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)

#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD

#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc)      (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc)      (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc)       (cc >= GGML_CUDA_CC_NG)

#ifdef __CUDA_ARCH_LIST__
constexpr bool ggml_cuda_has_arch_impl(int) {
    return false;
}

template<class ... Archs>
constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
}

constexpr bool ggml_cuda_has_arch(const int arch) {
    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
}

constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
    if (cur == 0) {
        GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
    }
    return cur;
}

template<class ... Archs>
constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
    if (first <= arch && first > cur) {
        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
    } else {
        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
    }
}

constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
}
#else
static int ggml_cuda_highest_compiled_arch(const int arch) {
    return arch;
}
#endif // __CUDA_ARCH_LIST__
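
// Worked example (illustrative, not part of the code): __CUDA_ARCH_LIST__ is a comma-separated list of the
// architectures this file was compiled for, e.g. 610,750,860 when building for compute 6.1/7.5/8.6. With that list:
//
//   ggml_cuda_highest_compiled_arch(750); // -> 750 (exact match is compiled in)
//   ggml_cuda_highest_compiled_arch(800); // -> 750 (highest compiled arch that does not exceed 800)
//   ggml_cuda_highest_compiled_arch(600); // -> aborts: ggml was not compiled with any CUDA arch <= 600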
// ---------------------------------------------------------------------------------------------------------

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

#define GGML_CUDA_MAX_STREAMS 8

[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

#define CUDA_CHECK_GEN(err, success, error_fn)                                   \
    do {                                                                         \
        auto err_ = (err);                                                       \
        if (err_ != (success)) {                                                 \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
        }                                                                        \
    } while (0)

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
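
// Usage sketch (illustrative; my_api_call, MY_STATUS_OK and my_status_to_str are hypothetical, not part of ggml):
//
//   void * buf = nullptr;
//   CUDA_CHECK(cudaMalloc(&buf, 1024));                             // calls ggml_cuda_error() and aborts on failure
//   CUDA_CHECK_GEN(my_api_call(), MY_STATUS_OK, my_status_to_str);  // same pattern for any status enum + stringifier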
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
    switch (err) {
        case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
        case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
        default:                             return "unknown error";
    }
}
#endif // CUDART_VERSION >= 12000

#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)

#if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    cuGetErrorString(err, &err_str);
    return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif

#if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11010

#ifdef GGML_CUDA_F16
typedef half  dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float  dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16

#if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))

#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#define CP_ASYNC_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
#define FLASH_ATTN_AVAILABLE
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)

static bool fp16_available(const int cc) {
    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
}

static bool fast_fp16_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fast_fp16_hardware_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
}

// Any FP16 tensor core instructions are available for ggml code.
static bool fp16_mma_available(const int cc) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
    return false;
#else
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
}

// To be used for feature selection of external libraries, e.g. cuBLAS.
static bool fp16_mma_hardware_available(const int cc) {
    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
        GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
}

static bool cp_async_available(const int cc) {
    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
}

static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
    return 64;
#else
    return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(__GFX9__) || defined(__GFX8__))
}

[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
           file_name, line, function_name, arch);
    GGML_UNUSED(arch_list);
#else
    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
           file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
    __trap();

    GGML_UNUSED(no_device_code); // suppress unused function warning

#if defined(GGML_USE_MUSA)
    __builtin_unreachable();
#endif // defined(GGML_USE_MUSA)
}

#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
// The compiler is not always able to unroll loops if they contain continue expressions.
// In such cases loop unrolling can still be achieved via recursion:
template <int n>
struct ggml_cuda_unroll {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(n - 1, args...);
        ggml_cuda_unroll<n - 1>{}(f, args...);
    }
};

template <>
struct ggml_cuda_unroll<1> {
    template <typename Func, typename... Args>
    __device__ void operator()(const Func & f, Args... args) const {
        f(0, args...);
    }
};
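
// Example (illustrative sketch): unrolling a loop whose body contains a `continue` by turning the body into a
// lambda; ggml_cuda_unroll<4> invokes it once for each i in 0..3. `x` and `acc` are hypothetical device-side values.
//
//   float acc = 0.0f;
//   ggml_cuda_unroll<4>{}([&](const int i, const float * vals) {
//       if (vals[i] < 0.0f) {
//           return; // plays the role of `continue` in the original loop
//       }
//       acc += vals[i];
//   }, x);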
template<int width = WARP_SIZE>
static __device__ __forceinline__ int warp_reduce_sum(int x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
    return __reduce_add_sync(0xffffffff, x);
#else
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
    }
    return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
    }
    return a;
#else
    NO_DEVICE_CODE;
    return a;
#endif // FP16_AVAILABLE
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
}
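
// Example (illustrative sketch): a common pattern for these warp primitives inside a kernel launched with a block
// width equal to the warp size. `data` and `n` are hypothetical kernel arguments.
//
//   float max_val = -FLT_MAX;
//   for (int i = threadIdx.x; i < n; i += WARP_SIZE) {
//       max_val = fmaxf(max_val, data[i]);
//   }
//   max_val = warp_reduce_max(max_val); // every lane now holds the warp-wide maximum
//
//   float sum = 0.0f;
//   for (int i = threadIdx.x; i < n; i += WARP_SIZE) {
//       sum += expf(data[i] - max_val);
//   }
//   sum = warp_reduce_sum(sum);         // every lane now holds the warp-wide sum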
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
    return __float2half(fmaxf(__half2float(a), __half2float(b)));
#else
    return __hmax(a, b);
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX

#else
    NO_DEVICE_CODE;
    GGML_UNUSED(b);
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
    return __hmax2(a, b);
#elif !defined(GGML_USE_HIP)
    half2 ret;
    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
    return ret;
#else
    GGML_UNUSED(a);
    GGML_UNUSED(b);
    NO_DEVICE_CODE;
#endif
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
}

#if CUDART_VERSION < CUDART_HMASK
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
    return mask_low | mask_high;
}
#endif // CUDART_VERSION < CUDART_HMASK

static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
    c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3) || defined(RDNA4)
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(RDNA1) || defined(__gfx900__)
    int tmp1;
    int tmp2;
    asm("\n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        "
        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
        : "v"(a), "v"(b)
    );
#else
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
    return c;

#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)

#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    return __dp4a(a, b, c);
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
    const int8_t * a8 = (const int8_t *) &a;
    const int8_t * b8 = (const int8_t *) &b;
    return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)

#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
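
// Worked example (illustrative): ggml_cuda_dp4a treats each int as four packed signed bytes and accumulates their
// dot product into c. On a little-endian device, a = 0x01020304 holds the bytes {4, 3, 2, 1} and b = 0x01010101
// holds {1, 1, 1, 1}, so ggml_cuda_dp4a(a, b, 10) == 10 + 4 + 3 + 2 + 1 == 20 on every backend branch above.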
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

static __device__ __forceinline__ float get_alibi_slope(
    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
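
// Usage sketch (illustrative): m0, m1 and n_head_log2 are typically derived from max_bias and the head count along
// these lines (as done by the soft-max / attention host code); the exact call sites may differ.
//
//   const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
//   const float    m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
//   const float    m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
//   const float    slope = get_alibi_slope(max_bias, h, n_head_log2, m0, m1); // 1.0f whenever max_bias <= 0.0f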
template <ggml_type type>
struct ggml_cuda_type_traits;

template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
    static constexpr int qk = 1;
    static constexpr int qr = 1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = QK4_0;
    static constexpr int qr = QR4_0;
    static constexpr int qi = QI4_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
    static constexpr int qk = QK4_1;
    static constexpr int qr = QR4_1;
    static constexpr int qi = QI4_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
    static constexpr int qk = QK5_0;
    static constexpr int qr = QR5_0;
    static constexpr int qi = QI5_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
    static constexpr int qk = QK5_1;
    static constexpr int qr = QR5_1;
    static constexpr int qi = QI5_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
    static constexpr int qk = QK8_0;
    static constexpr int qr = QR8_0;
    static constexpr int qi = QI8_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_K;
    static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_K;
    static constexpr int qi = QI3_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_K;
    static constexpr int qi = QI4_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR5_K;
    static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR6_K;
    static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XXS;
    static constexpr int qi = QI2_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XS;
    static constexpr int qi = QI2_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_S;
    static constexpr int qi = QI2_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_XXS;
    static constexpr int qi = QI3_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_S;
    static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_M;
    static constexpr int qi = QI1_M;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
    static constexpr int qk = QK4_NL;
    static constexpr int qr = QR4_NL;
    static constexpr int qi = QI4_NL;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_S;
    static constexpr int qi = QI3_S;
};
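
// Example (illustrative sketch): the traits above are meant to be consumed at compile time by templated kernels,
// e.g. to size per-block loops from qk (values per block), qr and qi (see the QK*/QR*/QI* definitions in
// ggml-common.h). `values_per_block` is a hypothetical helper, not part of this file.
//
//   template <ggml_type type>
//   static __device__ __forceinline__ constexpr int values_per_block() {
//       return ggml_cuda_type_traits<type>::qk; // e.g. QK4_0 for GGML_TYPE_Q4_0, QK_K for the k-quants
//   }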
//////////////////////

struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
        int     cc;              // compute capability
        int     nsm;             // number of streaming multiprocessors
        size_t  smpb;            // max. shared memory per block
        size_t  smpbo;           // max. shared memory per block (with opt-in)
        bool    integrated;      // device is integrated as opposed to discrete
        bool    vmm;             // virtual memory support
        size_t  vmm_granularity; // granularity of virtual memory
        size_t  total_vram;
        int     warp_size;       // number of threads in a dispatch
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};

const ggml_cuda_device_info & ggml_cuda_info();

void ggml_cuda_set_device(int device);
int  ggml_cuda_get_device();

struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void   free(void * ptr, size_t size) = 0;
};

template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr != nullptr) {
            pool->free(ptr, actual_size);
        }
    }

    // size is in number of elements
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
        return ptr;
    }

    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc & operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc & operator=(ggml_cuda_pool_alloc &&) = delete;
};
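
// Usage sketch (illustrative): ggml_cuda_pool_alloc is an RAII wrapper around a ggml_cuda_pool; the buffer is
// returned to the pool when the object goes out of scope. `ctx` stands for a ggml_backend_cuda_context (defined below).
//
//   {
//       ggml_cuda_pool_alloc<float> tmp(ctx.pool(), 4096); // reserves room for 4096 floats from the device pool
//       float * d_tmp = tmp.get();                         // raw device pointer, valid for the lifetime of tmp
//       // ... launch kernels that read/write d_tmp on ctx.stream() ...
//   }                                                      // destructor hands the buffer back to the pool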
// backend interface

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};

#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
#define USE_CUDA_GRAPH
#endif

struct ggml_graph_node_properties {
    void * node_address;
    ggml_op node_op;
    int64_t ne[GGML_MAX_DIMS];
    size_t nb[GGML_MAX_DIMS];
    void * src_address[GGML_MAX_SRC];
    int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};

struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    std::vector<cudaKernelNodeParams> params;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
    bool use_cpy_indirection = false;
    std::vector<char *> cpy_dest_ptrs;
    char ** dest_ptrs_d;
    int dest_ptrs_size = 0;
    // Index to allow each cpy kernel to be aware of its position within the graph
    // relative to other cpy nodes.
    int graph_cpynode_index = -1;
#endif
};
struct ggml_backend_cuda_context {
    int device;
    std::string name;
    cudaEvent_t copy_event = nullptr;

    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

    std::unique_ptr<ggml_cuda_graph> cuda_graph;

    explicit ggml_backend_cuda_context(int device) :
        device(device),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ~ggml_backend_cuda_context();

    cudaStream_t stream(int device, int stream) {
        if (streams[device][stream] == nullptr) {
            ggml_cuda_set_device(device);
            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
        }
        return streams[device][stream];
    }

    cudaStream_t stream() {
        return stream(device, 0);
    }

    cublasHandle_t cublas_handle(int device) {
        if (cublas_handles[device] == nullptr) {
            ggml_cuda_set_device(device);
            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
        }
        return cublas_handles[device];
    }

    cublasHandle_t cublas_handle() {
        return cublas_handle(device);
    }

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

    ggml_cuda_pool & pool(int device) {
        if (pools[device] == nullptr) {
            pools[device] = new_pool_for_device(device);
        }
        return *pools[device];
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
};
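
// Usage sketch (illustrative): typical host-side use of ggml_backend_cuda_context, which lazily creates the
// per-device stream and cuBLAS handle on first use.
//
//   ggml_backend_cuda_context ctx(0);       // context bound to device 0
//   cudaStream_t   s = ctx.stream();        // main stream of the bound device, created on demand
//   cublasHandle_t h = ctx.cublas_handle(); // cuBLAS handle with TF32 tensor-op math enabled
//   CUBLAS_CHECK(cublasSetStream(h, s));    // e.g. make subsequent cuBLAS calls run on that stream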