common.cuh

  1. #pragma once
  2. #include "ggml.h"
  3. #include "ggml-impl.h"
  4. #include "ggml-cuda.h"
  5. #include <cstdint>
  6. #include <memory>
  7. #if defined(GGML_USE_HIP)
  8. #define GGML_COMMON_DECL_HIP
  9. #define GGML_COMMON_IMPL_HIP
  10. #else
  11. #define GGML_COMMON_DECL_CUDA
  12. #define GGML_COMMON_IMPL_CUDA
  13. #if defined(GGML_USE_MUSA)
  14. #define GGML_COMMON_DECL_MUSA
  15. #define GGML_COMMON_IMPL_MUSA
  16. #endif
  17. #endif
  18. #include "ggml-common.h"
  19. #include <array>
  20. #include <algorithm>
  21. #include <cassert>
  22. #include <cfloat>
  23. #include <cstdio>
  24. #include <string>
  25. #include <unordered_map>
  26. #include <vector>
  27. #if defined(GGML_USE_HIP)
  28. #include "vendors/hip.h"
  29. #elif defined(GGML_USE_MUSA)
  30. #include "vendors/musa.h"
  31. #else
  32. #include "vendors/cuda.h"
  33. #endif // defined(GGML_USE_HIP)
  34. #define STRINGIZE_IMPL(...) #__VA_ARGS__
  35. #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
  36. #define WARP_SIZE 32
  37. #define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
  38. #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
  39. #define GGML_CUDA_CC_PASCAL 600
  40. #define GGML_CUDA_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
  41. #define GGML_CUDA_CC_VOLTA 700
  42. #define GGML_CUDA_CC_TURING 750
  43. #define GGML_CUDA_CC_AMPERE 800
  44. #define GGML_CUDA_CC_ADA_LOVELACE 890
  45. // While Blackwell (BW) spans CC 1000, 1100 & 1200, we only integrate the Tensor Core instructions available on the 1200 family, see
  46. // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
  47. #define GGML_CUDA_CC_BLACKWELL 1200
  48. #define GGML_CUDA_CC_RUBIN 1300
  49. #define GGML_CUDA_CC_OFFSET_AMD 0x1000000
  50. #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
  51. #define GGML_CUDA_CC_IS_NVIDIA(cc) (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
  52. // AMD
  53. // GCN/CDNA, wave size is 64
  54. #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
  55. #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
  56. #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
  57. #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
  58. #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum for acc register renaming
  59. #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300
  60. // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
  61. #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
  62. #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
  63. #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
  64. #define GGML_CUDA_CC_RDNA3_5 (GGML_CUDA_CC_OFFSET_AMD + 0x1150) // AI 370, AI Max 395 laptops.
  65. #define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
  66. #define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
  67. #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
  68. #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
  69. #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
  70. #define GGML_CUDA_CC_IS_RDNA3_0(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA3_5)
  71. #define GGML_CUDA_CC_IS_RDNA3_5(cc) (cc >= GGML_CUDA_CC_RDNA3_5 && cc < GGML_CUDA_CC_RDNA4)
  72. #define GGML_CUDA_CC_IS_RDNA3(cc) (GGML_CUDA_CC_IS_RDNA3_0(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc))
  73. #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
  74. #define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
  75. #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
  76. #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2)
  77. #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3)
  78. #define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
  79. // Moore Threads
  80. #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons
  81. #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
  82. #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
  83. #define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
  84. #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
  85. #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
  86. #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
  87. #define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
  88. #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
  89. # define GGML_CUDA_USE_CUB
  90. #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
  91. #ifdef __CUDA_ARCH_LIST__
  92. constexpr bool ggml_cuda_has_arch_impl(int) {
  93. return false;
  94. }
  95. template<class ... Archs>
  96. constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
  97. return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
  98. }
  99. constexpr bool ggml_cuda_has_arch(const int arch) {
  100. return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
  101. }
  102. constexpr int ggml_cuda_highest_compiled_arch_impl(const int /*arch*/, const int cur) {
  103. if (cur == 0) {
  104. return -1;
  105. }
  106. return cur;
  107. }
  108. template<class ... Archs>
  109. constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
  110. if (first <= arch && first > cur) {
  111. return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
  112. } else {
  113. return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
  114. }
  115. }
  116. constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
  117. return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
  118. }
  119. #else
  120. static int ggml_cuda_highest_compiled_arch(const int arch) {
  121. return arch;
  122. }
  123. #endif // __CUDA_ARCH_LIST__
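// Worked example (illustrative): assuming this translation unit is compiled for the virtual
// architectures 610, 750 and 860, so that __CUDA_ARCH_LIST__ expands to 610,750,860, the helpers
// above would resolve as follows:
//
//   static_assert( ggml_cuda_has_arch(750), "750 is in the list");
//   static_assert(!ggml_cuda_has_arch(700), "700 is not in the list");
//   static_assert(ggml_cuda_highest_compiled_arch(800) == 750, "highest compiled arch <= 800");
//   static_assert(ggml_cuda_highest_compiled_arch(600) == -1,  "no compiled arch is <= 600");
//
// The -1 case is ggml_cuda_highest_compiled_arch_impl terminating with cur still equal to 0.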
  124. // ---------------------------------------------------------------------------------------------------------
  125. #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
  126. #define GGML_CUDA_MAX_STREAMS 8
  127. [[noreturn]]
  128. void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
  129. #define CUDA_CHECK_GEN(err, success, error_fn) \
  130. do { \
  131. auto err_ = (err); \
  132. if (err_ != (success)) { \
  133. ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
  134. } \
  135. } while (0)
  136. #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
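// Usage sketch (illustrative): any runtime call returning cudaError_t can be wrapped so that a failure
// aborts via ggml_cuda_error with the failing expression, file and line. `stream` is a placeholder for
// an existing cudaStream_t.
//
//   void * buf = nullptr;
//   CUDA_CHECK(cudaMalloc(&buf, 1024));
//   CUDA_CHECK(cudaMemsetAsync(buf, 0, 1024, stream));
//   CUDA_CHECK(cudaFree(buf));
//
// CUBLAS_CHECK and CU_CHECK below follow the same pattern for cublasStatus_t and CUresult.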
  137. #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
  138. static const char * cublas_get_error_str(const cublasStatus_t err) {
  139. return cublasGetStatusString(err);
  140. }
  141. #else
  142. static const char * cublas_get_error_str(const cublasStatus_t err) {
  143. switch (err) {
  144. case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
  145. case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
  146. case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
  147. case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
  148. case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
  149. case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
  150. case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
  151. case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
  152. case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
  153. default: return "unknown error";
  154. }
  155. }
  156. #endif // CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
  157. #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
  158. #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)
  159. static const char * cu_get_error_str(CUresult err) {
  160. const char * err_str;
  161. cuGetErrorString(err, &err_str);
  162. return err_str;
  163. }
  164. #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
  165. #endif
  166. #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
  167. # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
  168. do { \
  169. static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
  170. const int id = ggml_cuda_get_device(); \
  171. if (!shared_memory_limit_raised[id]) { \
  172. CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
  173. shared_memory_limit_raised[id] = true; \
  174. } \
  175. } while (0)
  176. #else
  177. # define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
  178. do { \
  179. GGML_UNUSED(nbytes); \
  180. } while (0)
  181. #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
  182. #if CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA)
  183. #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
  184. #else
  185. #define GGML_CUDA_ASSUME(x)
  186. #endif // CUDART_VERSION >= 11010
  187. #if (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
  188. #define GGML_USE_VMM
  189. #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
  190. #if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
  191. #define FP16_AVAILABLE
  192. #endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
  193. #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
  194. #define FAST_FP16_AVAILABLE
  195. #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
  196. #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
  197. #define AMD_MFMA_AVAILABLE
  198. #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)
  199. #if defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
  200. #define AMD_WMMA_AVAILABLE
  201. #endif // defined(GGML_USE_HIP) && (defined(RDNA4) || defined(RDNA3))
  202. // The Volta instructions are in principle available on Turing or newer but they are effectively unusable:
  203. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
  204. #define VOLTA_MMA_AVAILABLE
  205. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
  206. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
  207. #define TURING_MMA_AVAILABLE
  208. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
  209. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  210. #define AMPERE_MMA_AVAILABLE
  211. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  212. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
  213. # define BLACKWELL_MMA_AVAILABLE
  214. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
  215. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  216. #define CP_ASYNC_AVAILABLE
  217. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  218. #if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
  219. #define FLASH_ATTN_AVAILABLE
  220. #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
  221. static bool fp16_available(const int cc) {
  222. return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
  223. (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
  224. }
  225. static bool fast_fp16_available(const int cc) {
  226. return GGML_CUDA_CC_IS_AMD(cc) ||
  227. (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
  228. (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
  229. }
  230. // To be used for feature selection of external libraries, e.g. cuBLAS.
  231. static bool fast_fp16_hardware_available(const int cc) {
  232. return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc) ||
  233. (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
  234. }
  235. // To be used for feature selection of external libraries, e.g. cuBLAS.
  236. static bool fp16_mma_hardware_available(const int cc) {
  237. return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
  238. GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc) ||
  239. (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
  240. }
  241. static bool bf16_mma_hardware_available(const int cc) {
  242. return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
  243. GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
  244. (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
  245. }
  246. static bool fp32_mma_hardware_available(const int cc) {
  247. return GGML_CUDA_CC_IS_CDNA(cc);
  248. }
  249. static bool amd_mfma_available(const int cc) {
  250. #if !defined(GGML_HIP_NO_MMQ_MFMA)
  251. return GGML_CUDA_CC_IS_CDNA(cc);
  252. #else
  253. return false;
  254. #endif //!defined(GGML_HIP_NO_MMQ_MFMA)
  255. }
  256. static bool amd_wmma_available(const int cc) {
  257. return (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc));
  258. }
  259. static bool volta_mma_available(const int cc) {
  260. return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA;
  261. }
  262. static bool turing_mma_available(const int cc) {
  263. return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
  264. }
  265. static bool ampere_mma_available(const int cc) {
  266. return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
  267. }
  268. static bool cp_async_available(const int cc) {
  269. return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
  270. }
  271. static bool blackwell_mma_available(const int cc) {
  272. return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
  273. ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
  274. }
  275. static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
  276. #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
  277. return 64;
  278. #else
  279. return 32;
  280. #endif // defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
  281. }
  282. // Maximum number of bytes that can be copied in a single instruction.
  283. static constexpr __device__ int ggml_cuda_get_max_cpy_bytes() {
  284. #ifdef GGML_USE_HIP
  285. return 16;
  286. #else
  287. #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
  288. return 16;
  289. #else
  290. return 8;
  291. #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
  292. #endif // GGML_USE_HIP
  293. }
  294. [[noreturn]]
  295. static __device__ void no_device_code(
  296. const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
  297. #if defined(GGML_USE_HIP)
  298. printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
  299. file_name, line, function_name, arch);
  300. GGML_UNUSED(arch_list);
  301. #else
  302. printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
  303. file_name, line, function_name, arch, arch_list);
  304. #endif // defined(GGML_USE_HIP)
  305. __trap();
  306. GGML_UNUSED(no_device_code); // suppress unused function warning
  307. #if defined(GGML_USE_MUSA)
  308. __builtin_unreachable();
  309. #endif // defined(GGML_USE_MUSA)
  310. }
  311. #ifdef __CUDA_ARCH__
  312. #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
  313. #else
  314. #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
  315. #endif // __CUDA_ARCH__
  316. // The compiler is not always able to unroll loops if they contain continue expressions.
  317. // In such cases loop unrolling can still be achieved via recursion:
  318. template <int n>
  319. struct ggml_cuda_unroll {
  320. template <typename Func, typename... Args>
  321. __device__ void operator()(const Func & f, Args... args) const {
  322. f(n - 1, args...);
  323. ggml_cuda_unroll<n - 1>{}(f, args...);
  324. }
  325. };
  326. template <>
  327. struct ggml_cuda_unroll<1> {
  328. template <typename Func, typename... Args>
  329. __device__ void operator()(const Func & f, Args... args) const {
  330. f(0, args...);
  331. }
  332. };
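// Usage sketch (illustrative): the functor receives the loop index as its first argument, counting down
// from n-1 to 0, so a loop that the compiler refuses to unroll because of a continue statement, e.g.
//
//   for (int i = 0; i < 4; ++i) { if (skip(i)) continue; acc += vals[i]; }
//
// can instead be unrolled via recursion as
//
//   ggml_cuda_unroll<4>{}([&](const int i) {
//       if (skip(i)) {
//           return; // plays the role of continue
//       }
//       acc += vals[i];
//   });
//
// skip, vals and acc are placeholders assumed to exist in the calling device function.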
  333. template<int width = WARP_SIZE>
  334. static __device__ __forceinline__ int warp_reduce_sum(int x) {
  335. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  336. return __reduce_add_sync(0xffffffff, x);
  337. #else
  338. #pragma unroll
  339. for (int offset = width/2; offset > 0; offset >>= 1) {
  340. x += __shfl_xor_sync(0xffffffff, x, offset, width);
  341. }
  342. return x;
  343. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  344. }
  345. template<int width = WARP_SIZE>
  346. static __device__ __forceinline__ float warp_reduce_sum(float x) {
  347. #pragma unroll
  348. for (int offset = width/2; offset > 0; offset >>= 1) {
  349. x += __shfl_xor_sync(0xffffffff, x, offset, width);
  350. }
  351. return x;
  352. }
  353. template<int width = WARP_SIZE>
  354. static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
  355. #pragma unroll
  356. for (int offset = width/2; offset > 0; offset >>= 1) {
  357. a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
  358. a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
  359. }
  360. return a;
  361. }
  362. template<int width = WARP_SIZE>
  363. static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
  364. #ifdef FP16_AVAILABLE
  365. #pragma unroll
  366. for (int offset = width/2; offset > 0; offset >>= 1) {
  367. a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
  368. }
  369. return a;
  370. #else
  371. NO_DEVICE_CODE;
  372. return a;
  373. #endif // FP16_AVAILABLE
  374. }
  375. template<int width = WARP_SIZE>
  376. static __device__ __forceinline__ int warp_reduce_all(int x) {
  377. if (width == ggml_cuda_get_physical_warp_size()) {
  378. return __all_sync(0xffffffff, x);
  379. } else {
  380. #pragma unroll
  381. for (int offset = width/2; offset > 0; offset >>= 1) {
  382. x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
  383. }
  384. return x;
  385. }
  386. }
  387. template<int width = WARP_SIZE>
  388. static __device__ __forceinline__ int warp_reduce_any(int x) {
  389. if (width == ggml_cuda_get_physical_warp_size()) {
  390. return __any_sync(0xffffffff, x);
  391. } else {
  392. #pragma unroll
  393. for (int offset = width/2; offset > 0; offset >>= 1) {
  394. x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
  395. }
  396. return x;
  397. }
  398. }
  399. template<int width = WARP_SIZE>
  400. static __device__ __forceinline__ float warp_reduce_max(float x) {
  401. #pragma unroll
  402. for (int offset = width/2; offset > 0; offset >>= 1) {
  403. x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
  404. }
  405. return x;
  406. }
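// Usage sketch (illustrative): a per-row sum where every lane of a warp accumulates a partial sum and
// the butterfly reduction leaves the total in all lanes. Assumes the kernel is launched with
// blockDim.x == WARP_SIZE; row_sum is a placeholder name.
//
//   __global__ void row_sum(const float * x, float * dst, const int ncols) {
//       float sum = 0.0f;
//       for (int col = threadIdx.x; col < ncols; col += WARP_SIZE) {
//           sum += x[blockIdx.x*ncols + col];
//       }
//       sum = warp_reduce_sum(sum);
//       if (threadIdx.x == 0) {
//           dst[blockIdx.x] = sum;
//       }
//   }
//
// warp_reduce_max follows the same pattern with fmaxf instead of addition.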
  407. template<typename T, int width = WARP_SIZE>
  408. static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
  409. const int lane_id = threadIdx.x % width;
  410. #pragma unroll
  411. for (int offset = 1; offset < width; offset <<= 1) {
  412. const T t = __shfl_up_sync(0xffffffff, x, offset, width);
  413. if (lane_id >= offset) {
  414. x += t;
  415. }
  416. }
  417. return x;
  418. }
  419. template<int width = WARP_SIZE>
  420. static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
  421. const int lane_id = threadIdx.x % width;
  422. #pragma unroll
  423. for (int offset = 1; offset < width; offset <<= 1) {
  424. const float t_x = __shfl_up_sync(0xffffffff, a.x, offset, width);
  425. const float t_y = __shfl_up_sync(0xffffffff, a.y, offset, width);
  426. if (lane_id >= offset) {
  427. a.x += t_x;
  428. a.y += t_y;
  429. }
  430. }
  431. return a;
  432. }
  433. template<int width = WARP_SIZE>
  434. static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
  435. #ifdef FP16_AVAILABLE
  436. const int lane_id = threadIdx.x % width;
  437. #pragma unroll
  438. for (int offset = 1; offset < width; offset <<= 1) {
  439. const half2 t = __shfl_up_sync(0xffffffff, a, offset, width);
  440. if (lane_id >= offset) {
  441. a = __hadd2(a, t);
  442. }
  443. }
  444. return a;
  445. #else
  446. NO_DEVICE_CODE;
  447. return a;
  448. #endif // FP16_AVAILABLE
  449. }
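// Usage sketch (illustrative): warp_prefix_inclusive_sum gives each lane the sum of its own value and
// all lower lanes (with every lane holding 1 the results are 1, 2, ..., 32). A typical use is turning
// per-lane counts into write offsets; `count` is a placeholder for a per-lane int in a full warp.
//
//   const int incl  = warp_prefix_inclusive_sum(count);                        // inclusive scan
//   const int excl  = incl - count;                                            // exclusive offset
//   const int total = __shfl_sync(0xffffffff, incl, WARP_SIZE - 1, WARP_SIZE); // value of last lane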
  450. static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
  451. #ifdef FP16_AVAILABLE
  452. #if !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
  453. return __float2half(fmaxf(__half2float(a), __half2float(b)));
  454. #else
  455. return __hmax(a, b);
  456. #endif // !defined(GGML_USE_HIP) && CUDART_VERSION < CUDART_HMAX
  457. #else
  458. NO_DEVICE_CODE;
  459. GGML_UNUSED(b);
  460. return a;
  461. #endif // FP16_AVAILABLE
  462. }
  463. static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
  464. #if defined(GGML_USE_HIP)
  465. return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
  466. #elif CUDART_VERSION >= CUDART_HMAX
  467. return __hmax2(a, b);
  468. #else
  469. half2 ret;
  470. reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a), __low2float(b)));
  471. reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
  472. return ret;
  473. #endif
  474. }
  475. template<int width = WARP_SIZE>
  476. static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
  477. #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
  478. #pragma unroll
  479. for (int offset = width/2; offset > 0; offset >>= 1) {
  480. x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
  481. }
  482. return x;
  483. #else
  484. GGML_UNUSED(x);
  485. NO_DEVICE_CODE;
  486. #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || defined(GGML_USE_HIP)
  487. }
  488. #if (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || \
  489. (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
  490. static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
  491. const uint32_t mask_low = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
  492. const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
  493. return mask_low | mask_high;
  494. }
  495. #endif // (defined(CUDART_VERSION) && CUDART_VERSION < CUDART_HMASK) || defined(GGML_USE_HIP) || (defined(MUSART_VERSION) && MUSART_VERSION < MUSART_HMASK)
  496. static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
  497. #if defined(GGML_USE_HIP)
  498. #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
  499. c = __builtin_amdgcn_sdot4(a, b, c, false);
  500. #elif defined(RDNA3) || defined(RDNA4)
  501. c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
  502. #elif defined(RDNA1) || defined(__gfx900__)
  503. int tmp1;
  504. int tmp2;
  505. asm("\n \
  506. v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
  507. v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
  508. v_add3_u32 %0, %1, %2, %0 \n \
  509. v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
  510. v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
  511. v_add3_u32 %0, %1, %2, %0 \n \
  512. "
  513. : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
  514. : "v"(a), "v"(b)
  515. );
  516. #else
  517. const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
  518. const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
  519. c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
  520. #endif
  521. return c;
  522. #else // defined(GGML_USE_HIP)
  523. #if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
  524. return __dp4a(a, b, c);
  525. #else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
  526. const int8_t * a8 = (const int8_t *) &a;
  527. const int8_t * b8 = (const int8_t *) &b;
  528. return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
  529. #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
  530. #endif // defined(GGML_USE_HIP)
  531. }
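// Usage sketch (illustrative): both operands are four int8 values packed into one 32-bit int and the
// result is their byte-wise dot product accumulated into c, regardless of whether the hardware path
// (__dp4a / sdot4) or the scalar fallback is taken:
//
//   const int a = 0x04030201;                 // bytes {1, 2, 3, 4}, little-endian
//   const int b = 0x08070605;                 // bytes {5, 6, 7, 8}
//   const int c = ggml_cuda_dp4a(a, b, 10);   // 10 + 1*5 + 2*6 + 3*7 + 4*8 = 80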
  532. static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
  533. acc += v*u;
  534. }
  535. static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
  536. acc += v.x*u.x;
  537. acc += v.y*u.y;
  538. }
  539. #if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
  540. #define V_DOT2_F32_F16_AVAILABLE
  541. #endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
  542. static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
  543. #ifdef V_DOT2_F32_F16_AVAILABLE
  544. asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
  545. #else
  546. #ifdef FAST_FP16_AVAILABLE
  547. const float2 tmp = __half22float2(v*u);
  548. acc += tmp.x + tmp.y;
  549. #else
  550. const float2 tmpv = __half22float2(v);
  551. const float2 tmpu = __half22float2(u);
  552. acc += tmpv.x * tmpu.x;
  553. acc += tmpv.y * tmpu.y;
  554. #endif // FAST_FP16_AVAILABLE
  555. #endif // V_DOT2_F32_F16_AVAILABLE
  556. }
  557. static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
  558. #ifdef FAST_FP16_AVAILABLE
  559. acc += v*u;
  560. #else
  561. const float2 tmpv = __half22float2(v);
  562. const float2 tmpu = __half22float2(u);
  563. float2 tmpacc = __half22float2(acc);
  564. tmpacc.x += tmpv.x * tmpu.x;
  565. tmpacc.y += tmpv.y * tmpu.y;
  566. acc = make_half2(tmpacc.x, tmpacc.y);
  567. #endif // FAST_FP16_AVAILABLE
  568. }
  569. // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
  570. // Important: do not use this function if dst and src both point at registers.
  571. // Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
  572. // The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
  573. // If dst and src point at different address spaces then they are guaranteed to not be aliased.
  574. template <int nbytes, int alignment = 0>
  575. static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
  576. static_assert(
  577. nbytes <= ggml_cuda_get_max_cpy_bytes() || alignment == 0,
  578. "You are misusing the alignment parameter for ggml_cuda_memcpy_1. "
  579. "The intent is for the parameter to be used only as a workaround if one of the pointers is not properly aligned. "
  580. "If you use it to copy more bytes per call than ggml_cuda_get_max_cpy_bytes() the reads and writes may not be coalesced. "
  581. "Call ggml_cuda_memcpy_1 in a loop instead.");
  582. if constexpr (alignment != 0) {
  583. static_assert(nbytes % alignment == 0, "bad alignment");
  584. }
  585. constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
  586. #pragma unroll
  587. for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
  588. if constexpr (nb_per_cpy == 1) {
  589. ((char *) dst)[i] = ((const char *) src)[i];
  590. } else if constexpr (nb_per_cpy == 2) {
  591. ((short *) dst)[i] = ((const short *) src)[i];
  592. } else if constexpr (nb_per_cpy == 4) {
  593. ((int *) dst)[i] = ((const int *) src)[i];
  594. } else if constexpr (nb_per_cpy == 8) {
  595. ((int2 *) dst)[i] = ((const int2 *) src)[i];
  596. } else if constexpr (nb_per_cpy == 16) {
  597. ((int4 *) dst)[i] = ((const int4 *) src)[i];
  598. } else {
  599. static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
  600. }
  601. }
  602. }
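// Usage sketch (illustrative): copying from global memory into registers with a single wide transfer,
// and with the alignment parameter only as a fallback when a pointer is not aligned to the copy size.
// src is a placeholder for a const float * in global memory, i0 for an element offset.
//
//   float vals[4];
//   ggml_cuda_memcpy_1<4*sizeof(float)>(vals, src + i0);          // one 16-byte copy
//   ggml_cuda_memcpy_1<2*sizeof(float), 4>(vals, src + i0 + 1);   // 8 bytes as two 4-byte copies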
  603. static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
  604. #if CUDART_VERSION >= 12080
  605. const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
  606. return (float) e;
  607. #else
  608. uint32_t bits;
  609. if (x == 0) {
  610. bits = 0x00400000;
  611. } else {
  612. bits = (uint32_t) x << 23;
  613. }
  614. float result;
  615. memcpy(&result, &bits, sizeof(float));
  616. return result;
  617. #endif // CUDART_VERSION >= 12080
  618. }
  619. __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
  620. const uint8_t sign_bit = (x < 0.0f) << 3;
  621. float ax = fabsf(x) * e;
  622. // Positive LUT
  623. static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
  624. int best_i = 0;
  625. float best_err = fabsf(ax - pos_lut[0]);
  626. #pragma unroll
  627. for (int i = 1; i < 8; ++i) {
  628. const float err = fabsf(ax - pos_lut[i]);
  629. if (err < best_err) {
  630. best_err = err;
  631. best_i = i;
  632. }
  633. }
  634. return static_cast<uint8_t>(best_i | sign_bit);
  635. }
  636. // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
  637. // Precompute mp (m' in the paper) and L such that division
  638. // can be computed using a multiply (high 32b of 64b result)
  639. // and a shift:
  640. //
  641. // n/d = (mulhi(n, mp) + n) >> L;
  642. static const uint3 init_fastdiv_values(uint64_t d_64) {
  643. GGML_ASSERT(d_64 != 0);
  644. GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
  645. uint32_t d = (uint32_t)d_64;
  646. // compute L = ceil(log2(d));
  647. uint32_t L = 0;
  648. while (L < 32 && (uint32_t{ 1 } << L) < d) {
  649. L++;
  650. }
  651. uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
  652. // pack divisor as well to reduce error surface
  653. return make_uint3(mp, L, d);
  654. }
  655. static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
  656. // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
  657. // fastdiv_values.z is unused and optimized away by the compiler.
  658. // Compute high 32 bits of n * mp
  659. const uint32_t hi = __umulhi(n, fastdiv_values.x);
  660. // add n, apply bit shift
  661. return (hi + n) >> fastdiv_values.y;
  662. }
  663. static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
  664. // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
  665. return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
  666. }
  667. // Calculate both division and modulo at once, returns <n/divisor, n%divisor>
  668. static __device__ __forceinline__ uint2 fast_div_modulo(uint32_t n, const uint3 fastdiv_values) {
  669. // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
  670. const uint32_t div_val = fastdiv(n, fastdiv_values);
  671. const uint32_t mod_val = n - div_val * fastdiv_values.z;
  672. return make_uint2(div_val, mod_val);
  673. }
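// Usage sketch (illustrative): the divisor is fixed on the host, packed once with init_fastdiv_values
// and passed to the kernel, so the device code never executes an integer division. ne01 is a
// placeholder name for a tensor dimension used as the divisor.
//
//   const uint3 ne01_fd = init_fastdiv_values(11);               // host: mp, L and 11 packed into a uint3
//
//   const uint32_t row = fastdiv        (123, ne01_fd);          // device: 123 / 11 = 11
//   const uint32_t col = fastmodulo     (123, ne01_fd);          //         123 % 11 = 2
//   const uint2    rc  = fast_div_modulo(123, ne01_fd);          //         {11, 2} in one call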
  674. typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
  675. static __device__ __forceinline__ float get_alibi_slope(
  676. const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
  677. ) {
  678. if (max_bias <= 0.0f) {
  679. return 1.0f;
  680. }
  681. const float base = h < n_head_log2 ? m0 : m1;
  682. const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
  683. return powf(base, exph);
  684. }
  685. template <ggml_type type>
  686. struct ggml_cuda_type_traits;
  687. template<>
  688. struct ggml_cuda_type_traits<GGML_TYPE_F16> {
  689. static constexpr int qk = 1;
  690. static constexpr int qr = 1;
  691. };
  692. template<>
  693. struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
  694. static constexpr int qk = QK4_0;
  695. static constexpr int qr = QR4_0;
  696. static constexpr int qi = QI4_0;
  697. };
  698. template<>
  699. struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
  700. static constexpr int qk = QK4_1;
  701. static constexpr int qr = QR4_1;
  702. static constexpr int qi = QI4_1;
  703. };
  704. template<>
  705. struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
  706. static constexpr int qk = QK5_0;
  707. static constexpr int qr = QR5_0;
  708. static constexpr int qi = QI5_0;
  709. };
  710. template<>
  711. struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
  712. static constexpr int qk = QK5_1;
  713. static constexpr int qr = QR5_1;
  714. static constexpr int qi = QI5_1;
  715. };
  716. template<>
  717. struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
  718. static constexpr int qk = QK8_0;
  719. static constexpr int qr = QR8_0;
  720. static constexpr int qi = QI8_0;
  721. };
  722. template<>
  723. struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
  724. static constexpr int qk = QK_MXFP4;
  725. static constexpr int qr = QR_MXFP4;
  726. static constexpr int qi = QI_MXFP4;
  727. };
  728. template<>
  729. struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
  730. static constexpr int qk = QK_K;
  731. static constexpr int qr = QR2_K;
  732. static constexpr int qi = QI2_K;
  733. };
  734. template<>
  735. struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
  736. static constexpr int qk = QK_K;
  737. static constexpr int qr = QR3_K;
  738. static constexpr int qi = QI3_K;
  739. };
  740. template<>
  741. struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
  742. static constexpr int qk = QK_K;
  743. static constexpr int qr = QR4_K;
  744. static constexpr int qi = QI4_K;
  745. };
  746. template<>
  747. struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
  748. static constexpr int qk = QK_K;
  749. static constexpr int qr = QR5_K;
  750. static constexpr int qi = QI5_K;
  751. };
  752. template<>
  753. struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
  754. static constexpr int qk = QK_K;
  755. static constexpr int qr = QR6_K;
  756. static constexpr int qi = QI6_K;
  757. };
  758. template<>
  759. struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
  760. static constexpr int qk = QK_K;
  761. static constexpr int qr = QR2_XXS;
  762. static constexpr int qi = QI2_XXS;
  763. };
  764. template<>
  765. struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
  766. static constexpr int qk = QK_K;
  767. static constexpr int qr = QR2_XS;
  768. static constexpr int qi = QI2_XS;
  769. };
  770. template<>
  771. struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
  772. static constexpr int qk = QK_K;
  773. static constexpr int qr = QR2_S;
  774. static constexpr int qi = QI2_S;
  775. };
  776. template<>
  777. struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
  778. static constexpr int qk = QK_K;
  779. static constexpr int qr = QR3_XXS;
  780. static constexpr int qi = QI3_XXS;
  781. };
  782. template<>
  783. struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
  784. static constexpr int qk = QK_K;
  785. static constexpr int qr = QR1_S;
  786. static constexpr int qi = QI1_S;
  787. };
  788. template<>
  789. struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
  790. static constexpr int qk = QK_K;
  791. static constexpr int qr = QR1_M;
  792. static constexpr int qi = QI1_M;
  793. };
  794. template<>
  795. struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
  796. static constexpr int qk = QK4_NL;
  797. static constexpr int qr = QR4_NL;
  798. static constexpr int qi = QI4_NL;
  799. };
  800. template<>
  801. struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
  802. static constexpr int qk = QK_K;
  803. static constexpr int qr = QR4_XS;
  804. static constexpr int qi = QI4_XS;
  805. };
  806. template<>
  807. struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
  808. static constexpr int qk = QK_K;
  809. static constexpr int qr = QR3_S;
  810. static constexpr int qi = QI3_S;
  811. };
  812. //////////////////////
  813. struct ggml_cuda_device_info {
  814. int device_count;
  815. struct cuda_device_info {
  816. int cc; // compute capability
  817. int nsm; // number of streaming multiprocessors
  818. size_t smpb; // max. shared memory per block
  819. size_t smpbo; // max. shared memory per block (with opt-in)
  820. bool integrated; // Device is integrated as opposed to discrete
  821. bool vmm; // virtual memory support
  822. size_t vmm_granularity; // granularity of virtual memory
  823. size_t total_vram;
  824. int warp_size; // number of threads in a warp (wavefront on AMD)
  825. bool supports_cooperative_launch; // whether cooperative launch is supported
  826. };
  827. cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
  828. std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
  829. };
  830. const ggml_cuda_device_info & ggml_cuda_info();
  831. void ggml_cuda_set_device(int device);
  832. int ggml_cuda_get_device();
  833. struct ggml_cuda_pool {
  834. virtual ~ggml_cuda_pool() = default;
  835. virtual void * alloc(size_t size, size_t * actual_size) = 0;
  836. virtual void free(void * ptr, size_t size) = 0;
  837. };
  838. template<typename T>
  839. struct ggml_cuda_pool_alloc {
  840. ggml_cuda_pool * pool = nullptr;
  841. T * ptr = nullptr;
  842. size_t actual_size = 0;
  843. ggml_cuda_pool_alloc() = default;
  844. explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
  845. }
  846. ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
  847. alloc(size);
  848. }
  849. ~ggml_cuda_pool_alloc() {
  850. if (ptr != nullptr) {
  851. pool->free(ptr, actual_size);
  852. }
  853. }
  854. // size is in number of elements
  855. T * alloc(size_t size) {
  856. GGML_ASSERT(pool != nullptr);
  857. GGML_ASSERT(ptr == nullptr);
  858. ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
  859. return ptr;
  860. }
  861. T * alloc(ggml_cuda_pool & pool, size_t size) {
  862. this->pool = &pool;
  863. return alloc(size);
  864. }
  865. T * get() {
  866. return ptr;
  867. }
  868. ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
  869. ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
  870. ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
  871. ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
  872. };
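// Usage sketch (illustrative): the allocation is returned to the pool when the ggml_cuda_pool_alloc
// object goes out of scope, so temporary device buffers can be handled RAII-style. ctx is assumed to be
// a ggml_backend_cuda_context (defined below), ne an element count and some_kernel a placeholder.
//
//   {
//       ggml_cuda_pool_alloc<float> tmp(ctx.pool(), ne);   // allocates ne*sizeof(float) bytes
//       some_kernel<<<grid, block, 0, ctx.stream()>>>(tmp.get(), ne);
//   }   // tmp.ptr is handed back to the pool here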
  873. // backend interface
  874. struct ggml_tensor_extra_gpu {
  875. void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
  876. cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
  877. };
  878. #if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS)
  879. #define USE_CUDA_GRAPH
  880. #endif
  881. struct ggml_graph_node_properties {
  882. void * node_address;
  883. ggml_op node_op;
  884. int64_t ne[GGML_MAX_DIMS];
  885. size_t nb[GGML_MAX_DIMS];
  886. void * src_address[GGML_MAX_SRC];
  887. int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
  888. };
  889. struct ggml_cuda_graph {
  890. #ifdef USE_CUDA_GRAPH
  891. ~ggml_cuda_graph() {
  892. if (instance != nullptr) {
  893. CUDA_CHECK(cudaGraphExecDestroy(instance));
  894. }
  895. if (graph != nullptr) {
  896. CUDA_CHECK(cudaGraphDestroy(graph));
  897. }
  898. }
  899. cudaGraph_t graph = nullptr;
  900. cudaGraphExec_t instance = nullptr;
  901. size_t num_nodes = 0;
  902. std::vector<cudaGraphNode_t> nodes;
  903. bool disable_due_to_gpu_arch = false;
  904. bool disable_due_to_too_many_updates = false;
  905. bool disable_due_to_failed_graph_capture = false;
  906. int number_consecutive_updates = 0;
  907. bool cuda_graphs_enabled = false;
  908. std::vector<ggml_graph_node_properties> ggml_graph_properties;
  909. std::vector<ggml_graph_node_properties> extraneous_srcs_properties;
  910. #endif
  911. };
  912. struct ggml_cuda_concurrent_event {
  913. std::vector<cudaEvent_t> join_events;
  914. cudaEvent_t fork_event = nullptr;
  915. int n_streams = 0;
  916. std::unordered_map<const ggml_tensor *, int> stream_mapping;
  917. // Original order of nodes in this concurrent region (before interleaving)
  918. // Used to restore grouping for fusion within streams
  919. std::vector<const ggml_tensor *> original_order;
  920. const ggml_tensor * join_node;
  921. ggml_cuda_concurrent_event() = default;
  922. ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
  923. ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
  924. explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
  925. join_events.resize(n_streams);
  926. for (size_t i = 0; i < join_events.size(); ++i) {
  927. CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
  928. }
  929. CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
  930. }
  931. ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
  932. : join_events(std::move(other.join_events))
  933. , fork_event(other.fork_event)
  934. , n_streams(other.n_streams)
  935. , stream_mapping(std::move(other.stream_mapping))
  936. , original_order(std::move(other.original_order))
  937. , join_node(other.join_node) {
  938. other.fork_event = nullptr;
  939. }
  940. // 1. check if any branches write to overlapping memory ranges (except the join node)
  941. // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
  942. // we assume all nodes have the same buffer
  943. bool is_valid() const {
  944. std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
  945. write_ranges.resize(n_streams);
  946. // get join_node's memory range to exclude from overlap checking.
  947. // multiple nodes can use join_node's buffer; we synchronize on the join node.
  948. const ggml_tensor * join_t = join_node->view_src ? join_node->view_src : join_node;
  949. const int64_t join_start = (int64_t) join_t->data;
  950. const int64_t join_end = join_start + ggml_nbytes(join_t);
  951. for (const auto & [tensor, stream] : stream_mapping) {
  952. const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
  953. const int64_t t_start = (int64_t) t->data;
  954. const int64_t t_end = t_start + ggml_nbytes(t);
  955. // skip tensors that overlap with join_node's buffer.
  956. if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
  957. continue;
  958. }
  959. // concurrent streams begin from 1
  960. write_ranges[stream - 1].emplace_back(t_start, t_end);
  961. }
  962. for (int i = 0; i < n_streams; ++i) {
  963. // sorts first by start then by end of write range
  964. std::sort(write_ranges[i].begin(), write_ranges[i].end());
  965. }
  966. bool writes_overlap = false;
  967. bool dependent_srcs = false;
  968. for (const auto & [tensor, stream] : stream_mapping) {
  969. const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
  970. const int64_t t_start = (int64_t) t->data;
  971. const int64_t t_end = t_start + ggml_nbytes(t);
  972. // skip tensors that overlap with join_node's buffer
  973. if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
  974. continue;
  975. }
  976. // check if this buffer's write data overlaps with another stream's
  977. std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
  978. for (int i = 0; i < n_streams; ++i) {
  979. if (i == stream - 1) {
  980. continue;
  981. }
  982. auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
  983. if (it != write_ranges[i].end()) {
  984. const std::pair<int64_t, int64_t> & other = *it;
  985. // std::lower_bound returns the first element where other >= data_range (lexicographically).
  986. // This guarantees other.first >= data_range.first.
  987. // Therefore, overlap occurs iff other.first < data_range.second
  988. // (i.e., the other range starts before this range ends).
  989. if (other.first < data_range.second) {
  990. GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
  991. writes_overlap = true;
  992. break;
  993. }
  994. }
  995. }
  996. // check that all srcs are either in the same branch or not assigned to any branch
  997. for (int i = 0; i < GGML_MAX_SRC; ++i) {
  998. if (!tensor->src[i]) {
  999. continue;
  1000. }
  1001. auto it = stream_mapping.find(tensor->src[i]);
  1002. if (it == stream_mapping.end()) {
  1003. continue;
  1004. }
  1005. if (it->second != stream) {
  1006. dependent_srcs = true;
  1007. break;
  1008. }
  1009. }
  1010. if (dependent_srcs || writes_overlap) {
  1011. break;
  1012. }
  1013. }
  1014. return !writes_overlap && !dependent_srcs;
  1015. }
  1016. ~ggml_cuda_concurrent_event() {
  1017. if (fork_event != nullptr) {
  1018. CUDA_CHECK(cudaEventDestroy(fork_event));
  1019. }
  1020. for (cudaEvent_t e : join_events) {
  1021. if (e != nullptr) {
  1022. CUDA_CHECK(cudaEventDestroy(e));
  1023. }
  1024. }
  1025. }
  1026. };
  1027. struct ggml_cuda_stream_context {
  1028. std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
  1029. void reset() {
  1030. concurrent_events.clear();
  1031. }
  1032. };
  1033. struct ggml_backend_cuda_context {
  1034. int device;
  1035. std::string name;
  1036. cudaEvent_t copy_event = nullptr;
  1037. cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
  1038. cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
  1039. std::unique_ptr<ggml_cuda_graph> cuda_graph;
  1040. int curr_stream_no = 0;
  1041. explicit ggml_backend_cuda_context(int device) :
  1042. device(device),
  1043. name(GGML_CUDA_NAME + std::to_string(device)) {
  1044. }
  1045. ggml_cuda_stream_context concurrent_stream_context;
  1046. ~ggml_backend_cuda_context();
  1047. cudaStream_t stream(int device, int stream) {
  1048. if (streams[device][stream] == nullptr) {
  1049. ggml_cuda_set_device(device);
  1050. CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
  1051. }
  1052. return streams[device][stream];
  1053. }
  1054. cudaStream_t stream() { return stream(device, curr_stream_no); }
  1055. ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
  1056. cublasHandle_t cublas_handle(int device) {
  1057. if (cublas_handles[device] == nullptr) {
  1058. ggml_cuda_set_device(device);
  1059. CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
  1060. CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
  1061. }
  1062. return cublas_handles[device];
  1063. }
  1064. cublasHandle_t cublas_handle() {
  1065. return cublas_handle(device);
  1066. }
  1067. // pool
  1068. std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
  1069. static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
  1070. ggml_cuda_pool & pool(int device) {
  1071. if (pools[device][curr_stream_no] == nullptr) {
  1072. pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
  1073. }
  1074. return *pools[device][curr_stream_no];
  1075. }
  1076. ggml_cuda_pool & pool() {
  1077. return pool(device);
  1078. }
  1079. };
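// Usage sketch (illustrative): streams, cuBLAS handles and memory pools are created lazily on first use,
// so a typical op implementation only needs the context. ggml_cuda_op_example and nbytes are placeholder
// names, not symbols declared in this header.
//
//   static void ggml_cuda_op_example(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
//       cudaStream_t stream = ctx.stream();                  // stream of the current device/stream_no
//       ggml_cuda_pool_alloc<char> tmp(ctx.pool(), nbytes);  // scratch memory from the per-stream pool
//       // ... launch kernels on `stream`, optionally using ctx.cublas_handle() ...
//       GGML_UNUSED(dst);
//   }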
  1080. struct ggml_cuda_mm_fusion_args_host {
  1081. const ggml_tensor * x_bias = nullptr;
  1082. const ggml_tensor * gate = nullptr;
  1083. const ggml_tensor * gate_bias = nullptr;
  1084. ggml_glu_op glu_op;
  1085. };
  1086. struct ggml_cuda_mm_fusion_args_device {
  1087. const void * x_bias = nullptr;
  1088. const void * gate = nullptr;
  1089. const void * gate_bias = nullptr;
  1090. ggml_glu_op glu_op;
  1091. };