common.cuh 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669
  1. #pragma once
  2. #include "ggml.h"
  3. #include "ggml-cuda.h"
  4. #include <cstdint>
  5. #include <memory>
  6. #if defined(GGML_USE_HIPBLAS)
  7. #define GGML_COMMON_DECL_HIP
  8. #define GGML_COMMON_IMPL_HIP
  9. #else
  10. #define GGML_COMMON_DECL_CUDA
  11. #define GGML_COMMON_IMPL_CUDA
  12. #if defined(GGML_USE_MUSA)
  13. #define GGML_COMMON_DECL_MUSA
  14. #define GGML_COMMON_IMPL_MUSA
  15. #endif
  16. #endif
  17. #include "ggml-common.h"
  18. #include <cstdio>
  19. #include <array>
  20. #include <cassert>
  21. #include <cfloat>
  22. #include <string>
  23. #include <vector>
  24. #if defined(GGML_USE_HIPBLAS)
  25. #include "vendors/hip.h"
  26. #elif defined(GGML_USE_MUSA)
  27. #include "vendors/musa.h"
  28. #else
  29. #include "vendors/cuda.h"
  30. #endif // defined(GGML_USE_HIPBLAS)
  // Two-step stringification: STRINGIZE lets its arguments be macro-expanded first,
  // STRINGIZE_IMPL then converts the resulting tokens into a string literal.
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

  // Number of threads per warp.
  #define WARP_SIZE 32

  // Maximum number of streams kept per device.
  #define GGML_CUDA_MAX_STREAMS 8

  // Last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses.
  #define MATRIX_ROW_PADDING 512

  // CUDA runtime version thresholds (compared against CUDART_VERSION).
  #define CUDART_HMAX  11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
  #define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

  // Compute capability thresholds.
  #define CC_PASCAL   600
  #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
  #define CC_VOLTA    700
  #define CC_TURING   750
  #define CC_AMPERE   800

  // AMD architectures are encoded as CC_OFFSET_AMD plus an architecture number.
  #define CC_OFFSET_AMD 1000000
  #define CC_RDNA1 (CC_OFFSET_AMD + 1010)
  #define CC_RDNA2 (CC_OFFSET_AMD + 1030)
  #define CC_RDNA3 (CC_OFFSET_AMD + 1100)

  #if defined(_MSC_VER)
  #pragma warning(disable: 4244 4267) // possible loss of data
  #endif
  // Reports a failed backend call and never returns to the caller.
  //   stmt - text of the statement that failed
  //   func - name of the enclosing function
  //   file - source file of the failing statement
  //   line - source line of the failing statement
  //   msg  - human-readable description of the error
  [[noreturn]] void ggml_cuda_error(
      const char * stmt,
      const char * func,
      const char * file,
      int          line,
      const char * msg);
  // Generic status check: evaluates `err` exactly once and, if the result differs from
  // `success`, forwards the stringified statement, the call site and error_fn(result)
  // to ggml_cuda_error.
  #define CUDA_CHECK_GEN(err, success, error_fn)                                   \
      do {                                                                         \
          auto err_ = (err);                                                       \
          if (err_ != (success)) {                                                 \
              ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
          }                                                                        \
      } while (0)

  // Status check for CUDA runtime calls (expects cudaSuccess).
  #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
  // Maps a cuBLAS status code to a printable name.
  // With CUDA >= 12.0 or MUSA the library-provided cublasGetStatusString is used;
  // otherwise a local table of the known status codes is consulted.
  static const char * cublas_get_error_str(const cublasStatus_t err) {
  #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
      return cublasGetStatusString(err);
  #else
      const char * name = "unknown error";
      switch (err) {
          case CUBLAS_STATUS_SUCCESS:          name = "CUBLAS_STATUS_SUCCESS";          break;
          case CUBLAS_STATUS_NOT_INITIALIZED:  name = "CUBLAS_STATUS_NOT_INITIALIZED";  break;
          case CUBLAS_STATUS_ALLOC_FAILED:     name = "CUBLAS_STATUS_ALLOC_FAILED";     break;
          case CUBLAS_STATUS_INVALID_VALUE:    name = "CUBLAS_STATUS_INVALID_VALUE";    break;
          case CUBLAS_STATUS_ARCH_MISMATCH:    name = "CUBLAS_STATUS_ARCH_MISMATCH";    break;
          case CUBLAS_STATUS_MAPPING_ERROR:    name = "CUBLAS_STATUS_MAPPING_ERROR";    break;
          case CUBLAS_STATUS_EXECUTION_FAILED: name = "CUBLAS_STATUS_EXECUTION_FAILED"; break;
          case CUBLAS_STATUS_INTERNAL_ERROR:   name = "CUBLAS_STATUS_INTERNAL_ERROR";   break;
          case CUBLAS_STATUS_NOT_SUPPORTED:    name = "CUBLAS_STATUS_NOT_SUPPORTED";    break;
          default:                                                                      break;
      }
      return name;
  #endif // CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
  }

  // Status check for cuBLAS calls (expects CUBLAS_STATUS_SUCCESS).
  #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
  #if !defined(GGML_USE_HIPBLAS)
  // Maps a CUDA driver API result code to a printable description.
  // The lookup itself can fail (the driver then leaves no usable string), so the
  // result of cuGetErrorString is checked and a fixed fallback text is returned
  // instead of handing a null/uninitialized pointer to the error reporter.
  static const char * cu_get_error_str(CUresult err) {
      const char * err_str = nullptr;
      if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
          return "unknown error";
      }
      return err_str;
  }

  // Status check for CUDA driver API calls (expects CUDA_SUCCESS).
  #define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
  #endif // !defined(GGML_USE_HIPBLAS)
  // GGML_CUDA_ASSUME(x) forwards x to __builtin_assume when building with CUDA >= 11.1
  // or MUSA; in every other configuration it expands to nothing.
  #if !(CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA))
  #define GGML_CUDA_ASSUME(x)
  #else
  #define GGML_CUDA_ASSUME(x) __builtin_assume(x)
  #endif // !(CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA))
  // dfloat / dfloat2: scalar and 2-component "dequantize float" types.
  // They are half precision when GGML_CUDA_F16 is defined and single precision otherwise.
  #ifndef GGML_CUDA_F16
  using dfloat  = float;  // dequantize float
  using dfloat2 = float2;
  #else
  using dfloat  = half;   // dequantize float
  using dfloat2 = half2;
  #endif // GGML_CUDA_F16
  // Compile-time feature flags for the architecture currently being compiled.

  // FP16_AVAILABLE: HIP on AMD, or a CUDA architecture of at least CC_PASCAL.
  #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
  #define FP16_AVAILABLE
  #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

  // FAST_FP16_AVAILABLE: FP16_AVAILABLE on anything except architecture 610.
  #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
  #define FAST_FP16_AVAILABLE
  #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610

  // FP16_MMA_AVAILABLE: not HIP on AMD, and a CUDA architecture of at least CC_VOLTA.
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
  #define FP16_MMA_AVAILABLE
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

  // INT8_MMA_AVAILABLE: not HIP on AMD, and a CUDA architecture of at least CC_TURING.
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
  #define INT8_MMA_AVAILABLE
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
  // Runtime counterpart of FAST_FP16_AVAILABLE.
  // Returns true for every compute capability `cc` of at least CC_PASCAL, except 610.
  static constexpr bool fast_fp16_available(const int cc) {
      return !(cc < CC_PASCAL || cc == 610);
  }
  // Runtime counterpart of FP16_MMA_AVAILABLE.
  // Returns true when `cc` lies in the range [CC_VOLTA, CC_OFFSET_AMD).
  static constexpr bool fp16_mma_available(const int cc) {
      return cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
  }
  // Runtime counterpart of INT8_MMA_AVAILABLE.
  // Returns true when `cc` lies in the range [CC_TURING, CC_OFFSET_AMD).
  static constexpr bool int8_mma_available(const int cc) {
      return cc >= CC_TURING && cc < CC_OFFSET_AMD;
  }
  // Device-side handler for kernels that have no implementation for the running architecture.
  // Prints a diagnostic naming the call site, the kernel and the architecture, then traps.
  //   file_name     - source file of the call site
  //   line          - source line of the call site
  //   function_name - name of the kernel / device function
  //   arch          - architecture value the code was compiled for
  //   arch_list     - list of compiled architectures (only printed in the non-HIP build)
  [[noreturn]]
  static __device__ void no_device_code(
      const char * file_name,
      const int    line,
      const char * function_name,
      const int    arch,
      const char * arch_list) {
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
             file_name, line, function_name, arch, arch_list);
  #else
      printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
             file_name, line, function_name, arch);
      GGML_UNUSED(arch_list);
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
      __trap();

      GGML_UNUSED(no_device_code); // suppress unused function warning
  }
  // NO_DEVICE_CODE marks a code path that must not be reached on the current architecture.
  // When __CUDA_ARCH__ is defined it calls no_device_code with the call site; otherwise it
  // expands to nothing.
  #ifndef __CUDA_ARCH__
  #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
  #else
  #define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
  #endif // __CUDA_ARCH__
  // Sums `x` across a warp with XOR shuffles at lane distances 16, 8, 4, 2 and 1.
  // The shuffles use the full participation mask 0xffffffff and a width of 32.
  // Returns the accumulated value held by the calling lane.
  static __device__ __forceinline__ float warp_reduce_sum(float x) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset >>= 1) {
          const float other = __shfl_xor_sync(0xffffffff, x, offset, 32);
          x += other;
      }
      return x;
  }
  // float2 overload: sums the .x and .y components independently across a warp,
  // using XOR shuffles at lane distances 16, 8, 4, 2 and 1 (full mask, width 32).
  static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset >>= 1) {
          const float other_x = __shfl_xor_sync(0xffffffff, a.x, offset, 32);
          a.x += other_x;
          const float other_y = __shfl_xor_sync(0xffffffff, a.y, offset, 32);
          a.y += other_y;
      }
      return a;
  }
  // half2 overload: sums both halves across a warp with XOR shuffles at lane
  // distances 16, 8, 4, 2 and 1 (full mask, width 32).
  // Without FP16_AVAILABLE the function reports NO_DEVICE_CODE and returns its input.
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
  #if !defined(FP16_AVAILABLE)
      NO_DEVICE_CODE;
      return a;
  #elif defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
      // HIP on AMD: add the low and high halves separately.
  #pragma unroll
      for (int offset = 16; offset >= 1; offset >>= 1) {
          const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
          reinterpret_cast<half&>(a.x) += __low2half(a_other);
          reinterpret_cast<half&>(a.y) += __high2half(a_other);
      }
      return a;
  #else
  #pragma unroll
      for (int offset = 16; offset >= 1; offset >>= 1) {
          const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
          a = __hadd2(a, a_other);
      }
      return a;
  #endif // !defined(FP16_AVAILABLE)
  }
  // Takes the maximum of `x` across a warp with XOR shuffles at lane distances
  // 16, 8, 4, 2 and 1 (full mask 0xffffffff, width 32), combining values with fmaxf.
  static __device__ __forceinline__ float warp_reduce_max(float x) {
  #pragma unroll
      for (int offset = 16; offset >= 1; offset >>= 1) {
          const float other = __shfl_xor_sync(0xffffffff, x, offset, 32);
          x = fmaxf(x, other);
      }
      return x;
  }
  // Maximum of two half values.
  // Outside HIP/AMD builds with a CUDA runtime older than CUDART_HMAX, the comparison is
  // done in float and converted back; otherwise __hmax is used directly.
  // Without FP16_AVAILABLE the function reports NO_DEVICE_CODE and returns `a`.
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
  #if !defined(FP16_AVAILABLE)
      NO_DEVICE_CODE;
      GGML_UNUSED(b);
      return a;
  #elif !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
      const float fa = __half2float(a);
      const float fb = __half2float(b);
      return __float2half(fmaxf(fa, fb));
  #else
      return __hmax(a, b);
  #endif // !defined(FP16_AVAILABLE)
  }
  // Element-wise maximum of two half2 values.
  // Non-HIP/AMD builds use __hmax2 when the CUDA runtime is at least CUDART_HMAX and
  // otherwise compare each half in float. HIP/AMD builds report NO_DEVICE_CODE.
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  #if CUDART_VERSION >= CUDART_HMAX
      return __hmax2(a, b);
  #else
      half2 ret;
      reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
      reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
      return ret;
  #endif // CUDART_VERSION >= CUDART_HMAX
  #else
      NO_DEVICE_CODE;
      GGML_UNUSED(b);
      // Explicit return, mirroring ggml_cuda_hmax: keeps every path of this non-void
      // function well-formed even where NO_DEVICE_CODE expands to nothing.
      return a;
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
  }
  // half2 overload: element-wise maximum across a warp with XOR shuffles at lane
  // distances 16, 8, 4, 2 and 1 (full mask, width 32), combined via ggml_cuda_hmax2.
  // Only implemented for non-HIP/AMD builds on architectures of at least CC_PASCAL;
  // elsewhere it reports NO_DEVICE_CODE and returns its input.
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
  #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
  #pragma unroll
      for (int mask = 16; mask > 0; mask >>= 1) {
          x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
      }
      return x;
  #else
      NO_DEVICE_CODE;
      // Explicit return so that control never flows off the end of this non-void
      // function when NO_DEVICE_CODE expands to nothing (no __CUDA_ARCH__ defined).
      return x;
  #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
  }
  #if CUDART_VERSION < CUDART_HMASK
  // Fallback for runtimes older than CUDART_HMASK: per-half "greater than" comparison
  // of two half2 values. Returns a 32-bit mask whose low 16 bits are all ones when
  // low(a) > low(b) and whose high 16 bits are all ones when high(a) > high(b).
  static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
      const bool low_gt  = float( __low2half(a)) > float( __low2half(b));
      const bool high_gt = float(__high2half(a)) > float(__high2half(b));

      uint32_t mask = 0;
      if (low_gt) {
          mask |= 0x0000FFFF;
      }
      if (high_gt) {
          mask |= 0xFFFF0000;
      }
      return mask;
  }
  #endif // CUDART_VERSION < CUDART_HMASK
  // Byte-wise dot product with accumulate: interprets `a` and `b` as four signed
  // 8-bit lanes each, multiplies them lane by lane and adds the products to `c`.
  // The implementation is selected at compile time:
  //   - HIP on AMD: sdot4 / sudot4 builtins, hand-written asm for gfx1010/gfx900,
  //     or a generic vector fallback;
  //   - CUDA with __CUDA_ARCH__ >= MIN_CC_DP4A: the __dp4a intrinsic;
  //   - older CUDA architectures: a scalar fallback.
  static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
  #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)

  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
      c = __builtin_amdgcn_sdot4(a, b, c, false);
  #elif defined(RDNA3)
      c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
  #elif defined(__gfx1010__) || defined(__gfx900__)
      // Two byte-lane products at a time are formed in temporaries and added into c.
      int tmp1;
      int tmp2;
      asm("\n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
          v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
          v_add3_u32 %0, %1, %2, %0 \n \
          "
          : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
          : "v"(a), "v"(b)
      );
  #else
      const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
      const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
      c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
  #endif
      return c;

  #elif __CUDA_ARCH__ >= MIN_CC_DP4A
      return __dp4a(a, b, c);
  #else
      // Scalar fallback: accumulate the four lane products into c, lane 0 first.
      const int8_t * a8 = (const int8_t *) &a;
      const int8_t * b8 = (const int8_t *) &b;
  #pragma unroll
      for (int lane = 0; lane < 4; ++lane) {
          c = c + a8[lane]*b8[lane];
      }
      return c;
  #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
  }
  // TODO: move to ggml-common.h
  // 16-entry table of signed 8-bit values, available to device code.
  static constexpr __device__ int8_t kvalues_iq4nl[16] = {
      -127, -104, -83, -65, -49, -35, -22, -10,
         1,   13,  25,  38,  53,  69,  89, 113,
  };
  // Pointer to a dequantization routine: receives the data pointer `vx`, an index `ib`
  // and an index `iqs`, and writes its result into the dfloat2 `v`.
  using dequantize_kernel_t = void (*)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
  // Computes the ALiBi slope for head `h`.
  //   max_bias    - when <= 0, the slope is 1.0f
  //   h           - head index
  //   n_head_log2 - split point between the two base/exponent rules
  //   m0, m1      - bases used below / at-or-above the split point
  // For h < n_head_log2 the result is m0^(h + 1); otherwise it is
  // m1^(2*(h - n_head_log2) + 1).
  static __device__ __forceinline__ float get_alibi_slope(
      const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
  ) {
      if (max_bias <= 0.0f) {
          return 1.0f;
      }

      float base;
      int   exph;
      if (h < n_head_log2) {
          base = m0;
          exph = h + 1;
      } else {
          base = m1;
          exph = 2*(h - n_head_log2) + 1;
      }

      return powf(base, exph);
  }
  // Compile-time constants per ggml tensor type.
  // Each specialization exposes `qk` and `qr`; every specialization except F16
  // additionally exposes `qi`. The values come from the QK*/QR*/QI* constants.
  template <ggml_type type>
  struct ggml_cuda_type_traits;

  // ---- unquantized ----

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_F16> {
      static constexpr int qk = 1;
      static constexpr int qr = 1;
  };

  // ---- Q4_0 / Q4_1 / Q5_0 / Q5_1 / Q8_0 ----

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
      static constexpr int qk = QK4_0;
      static constexpr int qr = QR4_0;
      static constexpr int qi = QI4_0;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
      static constexpr int qk = QK4_1;
      static constexpr int qr = QR4_1;
      static constexpr int qi = QI4_1;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
      static constexpr int qk = QK5_0;
      static constexpr int qr = QR5_0;
      static constexpr int qi = QI5_0;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
      static constexpr int qk = QK5_1;
      static constexpr int qr = QR5_1;
      static constexpr int qi = QI5_1;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
      static constexpr int qk = QK8_0;
      static constexpr int qr = QR8_0;
      static constexpr int qi = QI8_0;
  };

  // ---- K-quants (qk == QK_K) ----

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR2_K;
      static constexpr int qi = QI2_K;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR3_K;
      static constexpr int qi = QI3_K;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR4_K;
      static constexpr int qi = QI4_K;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR5_K;
      static constexpr int qi = QI5_K;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR6_K;
      static constexpr int qi = QI6_K;
  };

  // ---- IQ types ----

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR2_XXS;
      static constexpr int qi = QI2_XXS;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR2_XS;
      static constexpr int qi = QI2_XS;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR2_S;
      static constexpr int qi = QI2_S;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR3_XXS;
      static constexpr int qi = QI3_XXS;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR1_S;
      static constexpr int qi = QI1_S;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR1_M;
      static constexpr int qi = QI1_M;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
      static constexpr int qk = QK4_NL;
      static constexpr int qr = QR4_NL;
      static constexpr int qi = QI4_NL;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR4_XS;
      static constexpr int qi = QI4_XS;
  };

  template <>
  struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
      static constexpr int qk = QK_K;
      static constexpr int qr = QR3_S;
      static constexpr int qi = QI3_S;
  };
  403. //////////////////////
  // Summary of the devices visible to the backend: a device count, one property
  // record per device slot and a per-slot default tensor split value.
  struct ggml_cuda_device_info {
      int device_count;

      // Properties recorded for a single device.
      struct cuda_device_info {
          int    cc;              // compute capability
          int    nsm;             // number of streaming multiprocessors
          size_t smpb;            // max. shared memory per block
          size_t smpbo;           // max. shared memory per block (with opt-in)
          bool   vmm;             // virtual memory support
          size_t vmm_granularity; // granularity of virtual memory
          size_t total_vram;
      };

      // Both arrays are value-initialized and sized for GGML_CUDA_MAX_DEVICES slots.
      cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
      std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
  };
  // Returns a reference to the device information record.
  const ggml_cuda_device_info & ggml_cuda_info();

  // Selects `device` as the current device.
  void ggml_cuda_set_device(int device);

  // Returns the index of the current device.
  int ggml_cuda_get_device();
  // Abstract memory pool interface.
  struct ggml_cuda_pool {
      virtual ~ggml_cuda_pool() = default;

      // Requests `size` bytes; the size that was actually provided is written to
      // `*actual_size`. Returns the allocated pointer.
      virtual void * alloc(size_t size, size_t * actual_size) = 0;

      // Gives `ptr` back to the pool; `size` is the size associated with the allocation.
      virtual void free(void * ptr, size_t size) = 0;
  };
  // Scoped allocation of an array of T from a ggml_cuda_pool.
  // The buffer is handed back to the pool when the object is destroyed.
  // Objects of this type can be neither copied nor moved.
  template <typename T>
  struct ggml_cuda_pool_alloc {
      ggml_cuda_pool * pool = nullptr;
      T * ptr = nullptr;
      size_t actual_size = 0;

      ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
      ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
      ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;

      // Creates an empty holder without a pool; use alloc(pool, size) later.
      ggml_cuda_pool_alloc() = default;

      // Binds the holder to `pool` without allocating.
      explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
      }

      // Binds the holder to `pool` and immediately allocates `size` elements.
      ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
          alloc(size);
      }

      // Returns the buffer (if any) to the pool it came from.
      ~ggml_cuda_pool_alloc() {
          if (ptr == nullptr) {
              return;
          }
          pool->free(ptr, actual_size);
      }

      // Allocates room for `size` elements of T from the bound pool.
      // size is in number of elements
      // Asserts that a pool is bound and that nothing has been allocated yet.
      T * alloc(size_t size) {
          GGML_ASSERT(pool != nullptr);
          GGML_ASSERT(ptr == nullptr);

          const size_t n_bytes = size * sizeof(T);
          void * raw = pool->alloc(n_bytes, &this->actual_size);
          ptr = (T *) raw;
          return ptr;
      }

      // Binds the holder to `pool`, then allocates `size` elements from it.
      T * alloc(ggml_cuda_pool & pool, size_t size) {
          this->pool = &pool;
          return alloc(size);
      }

      // Returns the held pointer (nullptr when nothing is allocated).
      T * get() {
          return ptr;
      }
  };
  // backend interface

  // Per-tensor extra data used for tensors that are split across devices.
  struct ggml_tensor_extra_gpu {
      // 1 pointer for each device for split tensors
      void * data_device[GGML_CUDA_MAX_DEVICES];

      // events for synchronizing multiple GPUs
      cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
  };
  // USE_CUDA_GRAPH is defined only when GGML_CUDA_USE_GRAPHS is requested and the
  // CUDA runtime version is at least 12.0.
  #if defined(GGML_CUDA_USE_GRAPHS) && (CUDART_VERSION >= 12000)
  #define USE_CUDA_GRAPH
  #endif // defined(GGML_CUDA_USE_GRAPHS) && (CUDART_VERSION >= 12000)
  // Recorded properties of one graph node: its address, operation, the ne/nb
  // arrays (GGML_MAX_DIMS entries each) and the addresses of its sources.
  struct ggml_graph_node_properties {
      void *  node_address;
      ggml_op node_op;
      int64_t ne[GGML_MAX_DIMS];
      size_t  nb[GGML_MAX_DIMS];
      void *  src_address[GGML_MAX_SRC];
  };
  // State kept for CUDA graph execution. The struct is empty unless
  // USE_CUDA_GRAPH is defined.
  struct ggml_cuda_graph {
  #ifdef USE_CUDA_GRAPH
      cudaGraph_t     graph    = nullptr;
      cudaGraphExec_t instance = nullptr;
      size_t          num_nodes = 0;
      std::vector<cudaGraphNode_t>      nodes;
      std::vector<cudaKernelNodeParams> params;

      // Flags recording why graph usage has been disabled.
      bool disable_due_to_gpu_arch             = false;
      bool disable_due_to_too_many_updates     = false;
      bool disable_due_to_failed_graph_capture = false;

      int number_consecutive_updates = 0;
      std::vector<ggml_graph_node_properties> ggml_graph_properties;
      std::vector<char **> updated_kernel_arg;

      // Destroys the executable instance first, then the graph; null handles are skipped.
      ~ggml_cuda_graph() {
          if (instance != nullptr) {
              CUDA_CHECK(cudaGraphExecDestroy(instance));
          }
          if (graph != nullptr) {
              CUDA_CHECK(cudaGraphDestroy(graph));
          }
      }
  #endif // USE_CUDA_GRAPH
  };
  // Per-backend CUDA state: the owning device index, a display name, and lazily
  // created streams, cuBLAS handles and memory pools for every device slot.
  // Streams, cuBLAS handles and the copy event are destroyed in the destructor.
  struct ggml_backend_cuda_context {
      int device;
      std::string name;
      cudaEvent_t copy_event = nullptr;

      cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
      cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

      std::unique_ptr<ggml_cuda_graph> cuda_graph;

      // Creates a context for `device`; the name is GGML_CUDA_NAME followed by the index.
      explicit ggml_backend_cuda_context(int device) :
          device(device),
          name(GGML_CUDA_NAME + std::to_string(device)) {
      }

      // Releases the copy event and every stream / cuBLAS handle that was created.
      ~ggml_backend_cuda_context() {
          if (copy_event != nullptr) {
              CUDA_CHECK(cudaEventDestroy(copy_event));
          }
          for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
              for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                  if (streams[i][j] != nullptr) {
                      CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                  }
              }
              if (cublas_handles[i] != nullptr) {
                  CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
              }
          }
      }

      // Returns stream number `stream` of `device`, creating it (non-blocking) on first use.
      // Both indices must address an existing slot of the `streams` table.
      cudaStream_t stream(int device, int stream) {
          // Reject out-of-range indices instead of reading/writing past the fixed-size table.
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          GGML_ASSERT(stream >= 0 && stream < GGML_CUDA_MAX_STREAMS);
          if (streams[device][stream] == nullptr) {
              ggml_cuda_set_device(device);
              CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
          }
          return streams[device][stream];
      }

      // Returns stream 0 of this context's own device.
      cudaStream_t stream() {
          return stream(device, 0);
      }

      // Returns the cuBLAS handle of `device`, creating it on first use and switching
      // it to CUBLAS_TF32_TENSOR_OP_MATH.
      cublasHandle_t cublas_handle(int device) {
          // Reject out-of-range indices instead of indexing past `cublas_handles`.
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          if (cublas_handles[device] == nullptr) {
              ggml_cuda_set_device(device);
              CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
              CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
          }
          return cublas_handles[device];
      }

      // Returns the cuBLAS handle of this context's own device.
      cublasHandle_t cublas_handle() {
          return cublas_handle(device);
      }

      // pool
      std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

      // Factory for the pool of a given device; defined outside this header.
      static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

      // Returns the memory pool of `device`, creating it on first use.
      ggml_cuda_pool & pool(int device) {
          // Reject out-of-range indices instead of indexing past `pools`.
          GGML_ASSERT(device >= 0 && device < GGML_CUDA_MAX_DEVICES);
          if (pools[device] == nullptr) {
              pools[device] = new_pool_for_device(device);
          }
          return *pools[device];
      }

      // Returns the memory pool of this context's own device.
      ggml_cuda_pool & pool() {
          return pool(device);
      }
  };