common.cuh 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813
  1. #pragma once
  2. #include "ggml.h"
  3. #include "ggml-cuda.h"
  4. #include <cstdint>
  5. #include <memory>
  6. #if defined(GGML_USE_HIP)
  7. #define GGML_COMMON_DECL_HIP
  8. #define GGML_COMMON_IMPL_HIP
  9. #else
  10. #define GGML_COMMON_DECL_CUDA
  11. #define GGML_COMMON_IMPL_CUDA
  12. #if defined(GGML_USE_MUSA)
  13. #define GGML_COMMON_DECL_MUSA
  14. #define GGML_COMMON_IMPL_MUSA
  15. #endif
  16. #endif
  17. #include "ggml-common.h"
  18. #include <cstdio>
  19. #include <array>
  20. #include <cassert>
  21. #include <cfloat>
  22. #include <string>
  23. #include <vector>
  24. #if defined(GGML_USE_HIP)
  25. #include "vendors/hip.h"
  26. #elif defined(GGML_USE_MUSA)
  27. #include "vendors/musa.h"
  28. #else
  29. #include "vendors/cuda.h"
  30. #endif // defined(GGML_USE_HIP)
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

  #define WARP_SIZE 32
  #define CUDART_HMAX  11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
  #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

  // Compute capabilities ("cc"). NVIDIA values are used as-is; other vendors are shifted into disjoint
  // ranges by the offsets below, so one int identifies both the vendor and the architecture:
  //   [0, OFFSET_MTHREADS) NVIDIA, [OFFSET_MTHREADS, OFFSET_AMD) Moore Threads, [OFFSET_AMD, ...) AMD.
  // Every *_IS_*(cc) macro parenthesizes its argument so it is safe to pass arbitrary expressions.
  #define GGML_CUDA_CC_PASCAL       600
  #define GGML_CUDA_CC_DP4A         610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
  #define GGML_CUDA_CC_VOLTA        700
  #define GGML_CUDA_CC_TURING       750
  #define GGML_CUDA_CC_AMPERE       800
  #define GGML_CUDA_CC_ADA_LOVELACE 890
  #define GGML_CUDA_CC_OFFSET_AMD      0x1000000
  #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
  #define GGML_CUDA_CC_IS_NVIDIA(cc)   ((cc) < GGML_CUDA_CC_OFFSET_MTHREADS)

  // AMD
  // GCN/CDNA, wave size is 64
  #define GGML_CUDA_CC_GCN4   (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
  #define GGML_CUDA_CC_VEGA   (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
  #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
  #define GGML_CUDA_CC_CDNA   (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
  #define GGML_CUDA_CC_CDNA2  (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renaming
  #define GGML_CUDA_CC_CDNA3  (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

  // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32
  #define GGML_CUDA_CC_RDNA1  (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
  #define GGML_CUDA_CC_RDNA2  (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
  #define GGML_CUDA_CC_RDNA3  (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
  #define GGML_CUDA_CC_RDNA4  (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000

  #define GGML_CUDA_CC_IS_AMD(cc)   ((cc) >= GGML_CUDA_CC_OFFSET_AMD)
  #define GGML_CUDA_CC_IS_RDNA(cc)  ((cc) >= GGML_CUDA_CC_RDNA1)
  #define GGML_CUDA_CC_IS_RDNA1(cc) ((cc) >= GGML_CUDA_CC_RDNA1 && (cc) < GGML_CUDA_CC_RDNA2)
  #define GGML_CUDA_CC_IS_RDNA2(cc) ((cc) >= GGML_CUDA_CC_RDNA2 && (cc) < GGML_CUDA_CC_RDNA3)
  #define GGML_CUDA_CC_IS_RDNA3(cc) ((cc) >= GGML_CUDA_CC_RDNA3 && (cc) < GGML_CUDA_CC_RDNA4)
  #define GGML_CUDA_CC_IS_RDNA4(cc) ((cc) >= GGML_CUDA_CC_RDNA4)
  #define GGML_CUDA_CC_IS_GCN(cc)   ((cc) > GGML_CUDA_CC_OFFSET_AMD && (cc) < GGML_CUDA_CC_CDNA)
  #define GGML_CUDA_CC_IS_CDNA(cc)  ((cc) >= GGML_CUDA_CC_CDNA && (cc) < GGML_CUDA_CC_RDNA1)

  // Moore Threads
  // Architecture values are relative to GGML_CUDA_CC_OFFSET_MTHREADS. Each range is bounded above by the
  // next architecture (or by the start of the AMD range) so no AMD cc is ever classified as Moore Threads.
  #define GGML_CUDA_MUSA_ARCH_IS_QY1 (__MUSA_ARCH__ <= 210)

  #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
  #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
  #define GGML_CUDA_CC_NG  (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD

  #define GGML_CUDA_CC_IS_MTHREADS(cc) ((cc) >= GGML_CUDA_CC_OFFSET_MTHREADS && (cc) < GGML_CUDA_CC_OFFSET_AMD)
  #define GGML_CUDA_CC_IS_QY1(cc)      ((cc) >= GGML_CUDA_CC_QY1 && (cc) < GGML_CUDA_CC_QY2)
  #define GGML_CUDA_CC_IS_QY2(cc)      ((cc) >= GGML_CUDA_CC_QY2 && (cc) < GGML_CUDA_CC_NG)
  #define GGML_CUDA_CC_IS_NG(cc)       ((cc) >= GGML_CUDA_CC_NG && (cc) < GGML_CUDA_CC_OFFSET_AMD)
  #ifdef __CUDA_ARCH_LIST__
  // End of the architecture list: `arch` was not found.
  constexpr bool ggml_cuda_has_arch_impl(int) {
      return false;
  }

  // Linear search of the remaining compiled architectures for `arch`.
  template<class ... Archs>
  constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
      return first == arch ? true : ggml_cuda_has_arch_impl(arch, rest...);
  }

  // Whether device code for exactly `arch` was compiled (i.e. it appears in __CUDA_ARCH_LIST__).
  constexpr bool ggml_cuda_has_arch(const int arch) {
      return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
  }

  // End of the architecture list: `best` is the highest compiled arch <= `arch`, or 0 if there is none,
  // in which case ggml cannot run on this device and aborts.
  constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int best) {
      if (best == 0) {
          GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
      }
      return best;
  }

  // Keeps the running maximum of all compiled architectures that do not exceed `arch`.
  template<class ... Archs>
  constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int best, const int candidate, Archs... remaining) {
      const bool improves = candidate <= arch && candidate > best;
      return ggml_cuda_highest_compiled_arch_impl(arch, improves ? candidate : best, remaining...);
  }

  // Highest architecture in __CUDA_ARCH_LIST__ that a device of compute capability `arch` can run.
  constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
      return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
  }
  #else
  // No architecture list available: treat the device's own capability as the compiled one.
  static int ggml_cuda_highest_compiled_arch(const int arch) {
      const int compiled = arch;
      return compiled;
  }
  #endif // __CUDA_ARCH_LIST__
  // ---------------------------------------------------------------------------------------------------------

  // Quantized matrices are padded so the last row is a multiple of this many elements,
  // which avoids out-of-bounds memory accesses in kernels that read whole blocks.
  #define MATRIX_ROW_PADDING 512

  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
  #endif

  // Number of streams a backend context may hold per device.
  #define GGML_CUDA_MAX_STREAMS 8

  // Reports a failed API call (statement text, function, file, line, decoded message). Never returns.
  [[noreturn]]
  void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

  // Evaluates `call` exactly once; if its status differs from `ok_value`, reports it through
  // ggml_cuda_error together with the stringized call and the message produced by `to_string`.
  #define CUDA_CHECK_GEN(call, ok_value, to_string)                                     \
      do {                                                                               \
          auto status_ = (call);                                                         \
          if ((ok_value) != status_) {                                                   \
              ggml_cuda_error(#call, __func__, __FILE__, __LINE__, to_string(status_));  \
          }                                                                              \
      } while (0)

  #define CUDA_CHECK(call) CUDA_CHECK_GEN(call, cudaSuccess, cudaGetErrorString)
  // Returns a printable name for a cuBLAS status code; used by CUBLAS_CHECK diagnostics.
  static const char * cublas_get_error_str(const cublasStatus_t err) {
  #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
      // These builds provide a library helper that already does the conversion.
      return cublasGetStatusString(err);
  #else
      // Older toolkits: translate the known codes by hand.
      const char * name = "unknown error";
      switch (err) {
          case CUBLAS_STATUS_SUCCESS:          name = "CUBLAS_STATUS_SUCCESS";          break;
          case CUBLAS_STATUS_NOT_INITIALIZED:  name = "CUBLAS_STATUS_NOT_INITIALIZED";  break;
          case CUBLAS_STATUS_ALLOC_FAILED:     name = "CUBLAS_STATUS_ALLOC_FAILED";     break;
          case CUBLAS_STATUS_INVALID_VALUE:    name = "CUBLAS_STATUS_INVALID_VALUE";    break;
          case CUBLAS_STATUS_ARCH_MISMATCH:    name = "CUBLAS_STATUS_ARCH_MISMATCH";    break;
          case CUBLAS_STATUS_MAPPING_ERROR:    name = "CUBLAS_STATUS_MAPPING_ERROR";    break;
          case CUBLAS_STATUS_EXECUTION_FAILED: name = "CUBLAS_STATUS_EXECUTION_FAILED"; break;
          case CUBLAS_STATUS_INTERNAL_ERROR:   name = "CUBLAS_STATUS_INTERNAL_ERROR";   break;
          case CUBLAS_STATUS_NOT_SUPPORTED:    name = "CUBLAS_STATUS_NOT_SUPPORTED";    break;
          default:                                                                      break;
      }
      return name;
  #endif // CUDART_VERSION >= 12000
  }

  #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
  #if !defined(GGML_USE_HIP)
  // Converts a CUDA driver API result code into a printable string for CU_CHECK diagnostics.
  // cuGetErrorString() fails for unrecognized codes and does not produce a usable string in that case,
  // so the result is checked and a fixed fallback is returned instead of a null/indeterminate pointer
  // that would otherwise be handed to ggml_cuda_error() for printing.
  static const char * cu_get_error_str(CUresult err) {
      const char * err_str = nullptr;
      if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
          return "unknown error";
      }
      return err_str;
  }
  #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
  #endif
  // Optimizer hint: __builtin_assume is used only with CUDA >= 11.1 or MUSA; otherwise the hint is a no-op.
  #if !(CUDART_VERSION >= 11010 || defined(GGML_USE_MUSA))
  #define GGML_CUDA_ASSUME(x)
  #else
  #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
  #endif // CUDART_VERSION >= 11010

  // Types produced by the dequantization helpers: FP32 by default, FP16 when GGML_CUDA_F16 is defined.
  #ifndef GGML_CUDA_F16
  using dfloat  = float;  // dequantize float
  using dfloat2 = float2;
  #else
  using dfloat  = half;   // dequantize float
  using dfloat2 = half2;
  #endif // GGML_CUDA_F16

  // Virtual memory management is enabled unless the vendor-specific opt-out macro is defined.
  #if defined(GGML_USE_HIP)
  #  if !defined(GGML_HIP_NO_VMM)
  #    define GGML_USE_VMM
  #  endif
  #elif !defined(GGML_CUDA_NO_VMM)
  #  define GGML_USE_VMM
  #endif

  // ---- Device-code feature macros (keyed on the architecture currently being compiled) ----

  // FP16 arithmetic: always on HIP/AMD, Pascal or newer elsewhere.
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  #  define FP16_AVAILABLE
  #elif __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
  #  define FP16_AVAILABLE
  #endif

  // Fast FP16 arithmetic: FP16 is available and the target is not arch 610.
  #ifdef FP16_AVAILABLE
  #  if __CUDA_ARCH__ != 610
  #    define FAST_FP16_AVAILABLE
  #  endif
  #endif

  // Tensor-core / async-copy features gated on the NVIDIA architecture (never on HIP/AMD).
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
  #  if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
  #    define FP16_MMA_AVAILABLE
  #  endif
  #  if __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
  #    define NEW_MMA_AVAILABLE
  #  endif
  #  if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  #    define CP_ASYNC_AVAILABLE
  #  endif
  #endif

  // AMD FP16 matrix instructions via rocWMMA (flash attention builds only).
  #if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || defined(RDNA4))
  #  define FP16_MMA_AVAILABLE
  #endif

  // Flash attention unless explicitly disabled or compiling for a Moore Threads QY1 target.
  #if !defined(GGML_CUDA_NO_FA)
  #  if !(defined(GGML_USE_MUSA) && GGML_CUDA_MUSA_ARCH_IS_QY1)
  #    define FLASH_ATTN_AVAILABLE
  #  endif
  #endif
  // Host-side feature queries. `cc` is a ggml compute capability: NVIDIA values as-is, Moore Threads values
  // offset by GGML_CUDA_CC_OFFSET_MTHREADS and AMD values offset by GGML_CUDA_CC_OFFSET_AMD.
  // The *_available variants also account for which architectures ggml was compiled for;
  // the *_hardware_available variants only look at the device itself.

  // FP16 arithmetic can be used by ggml code.
  static bool fp16_available(const int cc) {
      return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
  }

  // Fast FP16 arithmetic can be used by ggml code (NVIDIA cc 610 is excluded).
  static bool fast_fp16_available(const int cc) {
      return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
  }

  // To be used for feature selection of external libraries, e.g. cuBLAS.
  static bool fast_fp16_hardware_available(const int cc) {
      return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_PASCAL && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
  }

  // Any FP16 tensor core instructions are available for ggml code.
  static bool fp16_mma_available(const int cc) {
  #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
      return false;
  #else
      return (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ||
          GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
  }

  // To be used for feature selection of external libraries, e.g. cuBLAS.
  static bool fp16_mma_hardware_available(const int cc) {
      return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) ||
          GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
  }

  // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
  static bool new_mma_available(const int cc) {
      return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
  }

  // Asynchronous copies (the device-side CP_ASYNC_AVAILABLE path) can be used by ggml code.
  // Restricted to NVIDIA ccs: a plain `cc < GGML_CUDA_CC_OFFSET_AMD` test would also accept every
  // Moore Threads cc, whose offset value is already far above GGML_CUDA_CC_AMPERE.
  static bool cp_async_available(const int cc) {
      return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
  }
  // Lanes per hardware warp for the target being compiled: the AMD wavefront size on HIP/AMD, otherwise 32.
  static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
      return 32;
  #else
      return __AMDGCN_WAVEFRONT_SIZE;
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
  }
  // Device-side fatal error for a kernel that has no implementation for the architecture it runs on:
  // prints the kernel/arch (and, on CUDA, the list of compiled archs), then traps.
  [[noreturn]]
  static __device__ void no_device_code(
      const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
      printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
             file_name, line, function_name, arch, arch_list);
  #else
      printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
             file_name, line, function_name, arch);
      GGML_UNUSED(arch_list);
  #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
      __trap();

      GGML_UNUSED(no_device_code); // suppress unused function warning

  #if defined(GGML_USE_MUSA)
      __builtin_unreachable();
  #endif // defined(GGML_USE_MUSA)
  }

  // In device compilation this reports the current call site; in host compilation it expands to nothing.
  #ifndef __CUDA_ARCH__
  #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
  #else
  #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
  #endif // __CUDA_ARCH__
  // Sums `x` over each group of `width` consecutive lanes (width assumed to be a power of two <= WARP_SIZE);
  // every lane of a group receives its group's total.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ int warp_reduce_sum(int x) {
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
      // __reduce_add_sync reduces over every lane named in the mask, i.e. the whole warp here. That only
      // matches the per-group semantics of the shuffle loop when the group spans the full warp; smaller
      // widths must fall through to the shuffle-based reduction below.
      if (width == WARP_SIZE) {
          return __reduce_add_sync(0xffffffff, x);
      }
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
  #pragma unroll
      for (int offset = width/2; offset > 0; offset >>= 1) {
          x += __shfl_xor_sync(0xffffffff, x, offset, width);
      }
      return x;
  }
  // Xor-shuffle sum of `x` within groups of `width` lanes; each lane ends with its group's total.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ float warp_reduce_sum(float x) {
  #pragma unroll
      for (int step = width >> 1; step >= 1; step /= 2) {
          const float partner = __shfl_xor_sync(0xffffffff, x, step, width);
          x += partner;
      }
      return x;
  }
  // Component-wise xor-shuffle sum of a float2 within groups of `width` lanes.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
  #pragma unroll
      for (int step = width >> 1; step >= 1; step /= 2) {
          const float partner_x = __shfl_xor_sync(0xffffffff, a.x, step, width);
          const float partner_y = __shfl_xor_sync(0xffffffff, a.y, step, width);
          a.x += partner_x;
          a.y += partner_y;
      }
      return a;
  }
  // Xor-shuffle sum of a half2 within groups of `width` lanes; requires FP16 arithmetic on the target.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
  #ifndef FP16_AVAILABLE
      NO_DEVICE_CODE;
      return a;
  #else
  #pragma unroll
      for (int step = width >> 1; step >= 1; step /= 2) {
          a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, step, width));
      }
      return a;
  #endif // FP16_AVAILABLE
  }
  // Xor-shuffle maximum of `x` within groups of `width` lanes; each lane ends with its group's maximum.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ float warp_reduce_max(float x) {
  #pragma unroll
      for (int step = width >> 1; step >= 1; step /= 2) {
          const float partner = __shfl_xor_sync(0xffffffff, x, step, width);
          x = fmaxf(x, partner);
      }
      return x;
  }
  // Maximum of two halves. On CUDA toolkits older than CUDART_HMAX (where __hmax is not known to work)
  // the comparison is done in FP32 and rounded back to half.
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
  #ifndef FP16_AVAILABLE
      NO_DEVICE_CODE;
      GGML_UNUSED(b);
      return a;
  #elif !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
      const float fa = __half2float(a);
      const float fb = __half2float(b);
      return __float2half(fmaxf(fa, fb));
  #else
      return __hmax(a, b);
  #endif // FP16_AVAILABLE
  }
  // Element-wise maximum of two half2 values, using the best primitive the toolchain offers:
  // per-element __hmax on HIP >= 5.7, __hmax2 on CUDA >= CUDART_HMAX, and an FP32 fallback on older CUDA.
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
  #if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
      return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
  #elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
      return __hmax2(a, b);
  #elif !defined(GGML_USE_HIP)
      half2 ret;
      reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
      reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
      return ret;
  #else
      GGML_UNUSED(b);
      NO_DEVICE_CODE;
      // Not reached on device (NO_DEVICE_CODE traps), but every path of a non-void function must return.
      return a;
  #endif
  }
  // Xor-shuffle element-wise maximum of a half2 within groups of `width` lanes.
  // Needs NVIDIA Pascal+ or HIP >= 5.7; other targets hit NO_DEVICE_CODE.
  template<int width = WARP_SIZE>
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
  #pragma unroll
      for (int offset = width/2; offset > 0; offset >>= 1) {
          x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
      }
      return x;
  #else
      NO_DEVICE_CODE;
      // Not reached on device (NO_DEVICE_CODE traps), but every path of a non-void function must return.
      return x;
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
  }
  #if CUDART_VERSION < CUDART_HMASK
  // Fallback for toolkits older than CUDA 12.0: compares the two halves of `a` and `b` and returns a
  // 32-bit mask with 0xFFFF in each 16-bit half where a > b and 0 where it is not.
  static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
      const bool low_greater  = float( __low2half(a)) > float( __low2half(b));
      const bool high_greater = float(__high2half(a)) > float(__high2half(b));
      return (low_greater ? 0x0000FFFFu : 0u) | (high_greater ? 0xFFFF0000u : 0u);
  }
  #endif // CUDART_VERSION < CUDART_HMASK
  // Adds the dot product of the four signed 8-bit lanes packed into `a` and `b` to `c` and returns it.
  // Uses a hardware dot-product instruction where one is available, otherwise byte-wise scalar math.
  static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
  #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
      // Non-AMD targets.
  #if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
      return __dp4a(a, b, c);
  #else
      const int8_t * a_bytes = (const int8_t *) &a;
      const int8_t * b_bytes = (const int8_t *) &b;
      return c + a_bytes[0]*b_bytes[0] + a_bytes[1]*b_bytes[1] + a_bytes[2]*b_bytes[2] + a_bytes[3]*b_bytes[3];
  #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A || defined(GGML_USE_MUSA)
  #else
      // AMD targets.
  #if defined(CDNA) || defined(RDNA2) || defined(__gfx906__)
      c = __builtin_amdgcn_sdot4(a, b, c, false);
  #elif defined(RDNA3) || defined(RDNA4)
      c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
  #elif defined(RDNA1) || defined(__gfx900__)
      // Two sign-extended byte products at a time (bytes 0/1, then bytes 2/3) are formed in the
      // scratch registers and folded into c with v_add3_u32.
      int prod_a;
      int prod_b;
      asm("\n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          "
          : "+v"(c), "=&v"(prod_a), "=&v"(prod_b)
          : "v"(a), "v"(b)
      );
  #else
      const int8x4_t a_vec = reinterpret_cast<const int8x4_t&>(a);
      const int8x4_t b_vec = reinterpret_cast<const int8x4_t&>(b);
      c += a_vec[0] * b_vec[0] + a_vec[1] * b_vec[1] + a_vec[2] * b_vec[2] + a_vec[3] * b_vec[3];
  #endif
      return c;
  #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
  }
  // Lookup table of the 16 signed 8-bit values used for IQ4_NL-style 4-bit codes
  // (presumably indexed by the 4-bit quant value — confirm at the use sites).
  // TODO: move to ggml-common.h
  static constexpr __device__ int8_t kvalues_iq4nl[16] = {
      -127, -104, -83, -65, -49, -35, -22, -10,
         1,   13,  25,  38,  53,  69,  89, 113,
  };

  // Function-pointer type for per-type dequantization helpers: reads from the quantized data `vx`
  // at block `ib` / position `iqs` and writes two dequantized values into `v`.
  using dequantize_kernel_t = void (*)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
  // ALiBi slope for attention head `h`. Returns 1 when max_bias <= 0. Otherwise the first n_head_log2
  // heads get m0 raised to 1, 2, 3, ... and the remaining heads get m1 raised to the odd powers 1, 3, 5, ...
  static __device__ __forceinline__ float get_alibi_slope(
      const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
  ) {
      if (max_bias <= 0.0f) {
          return 1.0f;
      }

      float base;
      int   exph;
      if (h < n_head_log2) {
          base = m0;
          exph = h + 1;
      } else {
          base = m1;
          exph = 2*(h - n_head_log2) + 1;
      }
      return powf(base, exph);
  }
  // Compile-time per-type constants for the CUDA kernels. qk/qr/qi mirror the QK*/QR*/QI* constants from
  // ggml-common.h (presumably values per block, quantization ratio and 32-bit ints per block — confirm there).
  template <ggml_type type>
  struct ggml_cuda_type_traits;

  // F16 is not block-quantized, so it only carries qk/qr.
  template<>
  struct ggml_cuda_type_traits<GGML_TYPE_F16> {
      static constexpr int qk = 1;
      static constexpr int qr = 1;
  };

  // Local helper that stamps out one quantized-type specialization; it is #undef'd right after use.
  #define GGML_CUDA_DEFINE_TYPE_TRAITS_(type_, qk_, qr_, qi_) \
      template<>                                              \
      struct ggml_cuda_type_traits<type_> {                   \
          static constexpr int qk = qk_;                      \
          static constexpr int qr = qr_;                      \
          static constexpr int qi = qi_;                      \
      }

  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q4_0,    QK4_0,  QR4_0,   QI4_0);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q4_1,    QK4_1,  QR4_1,   QI4_1);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q5_0,    QK5_0,  QR5_0,   QI5_0);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q5_1,    QK5_1,  QR5_1,   QI5_1);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q8_0,    QK8_0,  QR8_0,   QI8_0);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q2_K,    QK_K,   QR2_K,   QI2_K);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q3_K,    QK_K,   QR3_K,   QI3_K);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q4_K,    QK_K,   QR4_K,   QI4_K);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q5_K,    QK_K,   QR5_K,   QI5_K);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_Q6_K,    QK_K,   QR6_K,   QI6_K);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ2_XXS, QK_K,   QR2_XXS, QI2_XXS);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ2_XS,  QK_K,   QR2_XS,  QI2_XS);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ2_S,   QK_K,   QR2_S,   QI2_S);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ3_XXS, QK_K,   QR3_XXS, QI3_XXS);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ1_S,   QK_K,   QR1_S,   QI1_S);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ1_M,   QK_K,   QR1_M,   QI1_M);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ4_NL,  QK4_NL, QR4_NL,  QI4_NL);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ4_XS,  QK_K,   QR4_XS,  QI4_XS);
  GGML_CUDA_DEFINE_TYPE_TRAITS_(GGML_TYPE_IQ3_S,   QK_K,   QR3_S,   QI3_S);

  #undef GGML_CUDA_DEFINE_TYPE_TRAITS_
  //////////////////////

  // Per-device properties, exposed read-only through ggml_cuda_info().
  // Every member has a default initializer so a default-constructed object never holds indeterminate
  // values (previously device_count was the only member left uninitialized).
  struct ggml_cuda_device_info {
      int device_count = 0; // number of devices described in `devices`

      struct cuda_device_info {
          int     cc;                 // compute capability
          int     nsm;                // number of streaming multiprocessors
          size_t  smpb;               // max. shared memory per block
          size_t  smpbo;              // max. shared memory per block (with opt-in)
          bool    vmm;                // virtual memory support
          size_t  vmm_granularity;    // granularity of virtual memory
          size_t  total_vram;
          int     warp_size;          // Number of threads in a dispatch
      };

      cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

      std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
  };

  const ggml_cuda_device_info & ggml_cuda_info();

  void ggml_cuda_set_device(int device);
  int ggml_cuda_get_device();
  // Abstract allocator for temporary device buffers.
  //   alloc: reserves at least `size` bytes (presumably) and reports the size actually reserved through `actual_size`.
  //   free:  returns a buffer; callers pass the `actual_size` they received from alloc.
  struct ggml_cuda_pool {
      virtual ~ggml_cuda_pool() = default;

      virtual void * alloc(size_t size, size_t * actual_size) = 0;
      virtual void   free(void * ptr, size_t size)            = 0;
  };
  // Scoped handle for a buffer of `T` elements borrowed from a ggml_cuda_pool.
  // The buffer (if any) is given back to the pool, with the size the pool reported, on destruction.
  // The handle can be neither copied nor moved.
  template<typename T>
  struct ggml_cuda_pool_alloc {
      ggml_cuda_pool * pool = nullptr;
      T * ptr = nullptr;
      size_t actual_size = 0;

      ggml_cuda_pool_alloc() = default;

      explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {}

      ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
          alloc(size);
      }

      ~ggml_cuda_pool_alloc() {
          if (ptr == nullptr) {
              return;
          }
          pool->free(ptr, actual_size);
      }

      // size is in number of elements
      T * alloc(size_t size) {
          GGML_ASSERT(pool != nullptr);
          GGML_ASSERT(ptr == nullptr);
          void * raw = pool->alloc(size * sizeof(T), &actual_size);
          ptr = static_cast<T *>(raw);
          return ptr;
      }

      // Binds the handle to `pool` and then allocates `size` elements from it.
      T * alloc(ggml_cuda_pool & pool, size_t size) {
          this->pool = &pool;
          return alloc(size);
      }

      T * get() { return ptr; }

      ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
      ggml_cuda_pool_alloc & operator=(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc & operator=(ggml_cuda_pool_alloc &&) = delete;
  };
  // backend interface

  // Device-side bookkeeping attached to a tensor that is split across GPUs.
  struct ggml_tensor_extra_gpu {
      void *      data_device[GGML_CUDA_MAX_DEVICES];                // 1 pointer for each device for split tensors
      cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
  };

  // CUDA graph support is compiled in when either the CUDA or the HIP graphs option is enabled.
  #if defined(GGML_CUDA_USE_GRAPHS)
  #define USE_CUDA_GRAPH
  #elif defined(GGML_HIP_GRAPHS)
  #define USE_CUDA_GRAPH
  #endif

  // Snapshot of one ggml graph node: its address, op, shape/strides, source addresses and op params.
  // Presumably compared against the live graph to decide whether a captured CUDA graph is still valid.
  struct ggml_graph_node_properties {
      void *   node_address;
      ggml_op  node_op;
      int64_t  ne[GGML_MAX_DIMS];
      size_t   nb[GGML_MAX_DIMS];
      void *   src_address[GGML_MAX_SRC];
      int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
  };
  // Captured CUDA graph and the bookkeeping needed to reuse/update it. Empty unless USE_CUDA_GRAPH is defined.
  // The destructor releases the executable instance and the graph if they were created.
  struct ggml_cuda_graph {
  #ifdef USE_CUDA_GRAPH
      ~ggml_cuda_graph() {
          if (instance != nullptr) {
              CUDA_CHECK(cudaGraphExecDestroy(instance));
          }
          if (graph != nullptr) {
              CUDA_CHECK(cudaGraphDestroy(graph));
          }
      }
      cudaGraph_t graph = nullptr;
      cudaGraphExec_t instance = nullptr;
      size_t num_nodes = 0;
      std::vector<cudaGraphNode_t> nodes;
      std::vector<cudaKernelNodeParams> params;
      bool disable_due_to_gpu_arch = false;
      bool disable_due_to_too_many_updates = false;
      bool disable_due_to_failed_graph_capture = false;
      int number_consecutive_updates = 0;
      std::vector<ggml_graph_node_properties> ggml_graph_properties;
      bool use_cpy_indirection = false;
      std::vector<char *> cpy_dest_ptrs;
      // Default-initialized like every other member so it can be tested against nullptr before the
      // first allocation even if the object was not value-initialized.
      // NOTE(review): this destructor does not release it — confirm the code that allocates it also frees it.
      char ** dest_ptrs_d = nullptr;
      int dest_ptrs_size = 0;
      // Index to allow each cpy kernel to be aware of its position within the graph
      // relative to other cpy nodes.
      int graph_cpynode_index = -1;
  #endif
  };
  // State owned by one CUDA backend instance. Streams, cuBLAS handles and memory pools are created per
  // device on first request. The destructor destroys copy_event, the streams and the cuBLAS handles;
  // pools and cuda_graph are released by their unique_ptr owners afterwards.
  struct ggml_backend_cuda_context {
      int device;
      std::string name;
      cudaEvent_t copy_event = nullptr;

      cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
      cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

      std::unique_ptr<ggml_cuda_graph> cuda_graph;

      // pool
      std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

      explicit ggml_backend_cuda_context(int device) :
          device(device),
          name(GGML_CUDA_NAME + std::to_string(device)) {
      }

      ~ggml_backend_cuda_context() {
          if (copy_event != nullptr) {
              CUDA_CHECK(cudaEventDestroy(copy_event));
          }
          for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
              for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                  if (streams[i][j] == nullptr) {
                      continue;
                  }
                  CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
              }
              if (cublas_handles[i] == nullptr) {
                  continue;
              }
              CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
          }
      }

      // Stream `stream` of `device`, created as a non-blocking stream on first use.
      cudaStream_t stream(int device, int stream) {
          if (streams[device][stream] != nullptr) {
              return streams[device][stream];
          }
          ggml_cuda_set_device(device);
          CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
          return streams[device][stream];
      }

      cudaStream_t stream() {
          return stream(device, 0);
      }

      // cuBLAS handle of `device`, created on first use with TF32 tensor-op math enabled.
      cublasHandle_t cublas_handle(int device) {
          if (cublas_handles[device] != nullptr) {
              return cublas_handles[device];
          }
          ggml_cuda_set_device(device);
          CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
          CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
          return cublas_handles[device];
      }

      cublasHandle_t cublas_handle() {
          return cublas_handle(device);
      }

      static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

      // Memory pool of `device`, created via new_pool_for_device on first use.
      ggml_cuda_pool & pool(int device) {
          if (!pools[device]) {
              pools[device] = new_pool_for_device(device);
          }
          return *pools[device];
      }

      ggml_cuda_pool & pool() {
          return pool(device);
      }
  };