common.cuh

#pragma once

#include "ggml.h"
#include "ggml-cuda.h"

#include <memory>

#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#endif
#include "ggml-common.h"
#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <limits> // for std::numeric_limits, used by the HIP __vsubss4 fallback below
#include <string>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020

#endif // defined(GGML_USE_HIPBLAS)

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)

#define CC_PASCAL     600
#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA      700
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
// define this if you want to always fall back to the MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computations. the drawback is that this requires some extra amount of VRAM:
// - 7B quantized model:  +100-200 MB
// - 13B quantized model: +200-400 MB
//
//#define GGML_CUDA_FORCE_MMQ

// TODO: improve this to be correct for more hardware
//       for example, it currently fails for the GeForce GTX 1660, which is TURING arch (> VOLTA) but does not have tensor cores
#if !defined(GGML_CUDA_FORCE_MMQ)
#define CUDA_USE_TENSOR_CORES
#endif

#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE  32 // max batch size to use MMQ kernels when tensor cores are available

#define MATRIX_ROW_PADDING 512 // the last row of quantized matrices is padded to a multiple of this to avoid out-of-bounds memory accesses
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#define GGML_CUDA_MAX_STREAMS 8

[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

#define CUDA_CHECK_GEN(err, success, error_fn)                                   \
    do {                                                                         \
        auto err_ = (err);                                                       \
        if (err_ != (success)) {                                                 \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
        }                                                                        \
    } while (0)

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
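
// Illustrative sketch (not part of the original header): shows how CUDA_CHECK is meant to wrap
// runtime calls so that any failure is routed to ggml_cuda_error() together with the failing
// statement, file and line. The helper name, the 1 KiB buffer size and the stream argument are
// assumptions made for the example.
static inline void ggml_cuda_example_check_usage(cudaStream_t stream) {
    void * buf = nullptr;
    CUDA_CHECK(cudaMalloc(&buf, 1024));                // aborts with a descriptive message on failure
    CUDA_CHECK(cudaMemsetAsync(buf, 0, 1024, stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
    CUDA_CHECK(cudaFree(buf));
}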
#if CUDART_VERSION >= 12000
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
    switch (err) {
        case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
        case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
        default: return "unknown error";
    }
}
#endif // CUDART_VERSION >= 12000

#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)

#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    cuGetErrorString(err, &err_str);
    return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif

#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100

#ifdef GGML_CUDA_F16
typedef half  dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float  dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16

[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
           file_name, line, function_name, arch);
    GGML_UNUSED(arch_list);
#else
    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
           file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    __trap();

    GGML_UNUSED(no_device_code); // suppress unused function warning
}

#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
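
// Illustrative sketch (not part of the original header): the usual pattern for arch-guarded
// device code in this backend. The kernel name and the trivial FP16->FP32 body are hypothetical;
// the point is that architectures without the required features fall through to NO_DEVICE_CODE,
// which traps with a clear error message instead of silently producing wrong results.
static __global__ void ggml_cuda_example_arch_guarded(const half * x, float * dst) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    dst[i] = __half2float(x[i]); // the arch-specific fast path would go here
#else
    GGML_UNUSED(x);
    GGML_UNUSED(dst);
    NO_DEVICE_CODE;
#endif
}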
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}

static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
    }
    return a;
}

#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#else
    GGML_UNUSED(a);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}
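
// Illustrative sketch (not part of the original header): how the warp_reduce_* helpers are
// typically consumed. One warp cooperatively reduces WARP_SIZE consecutive values; after the
// butterfly shuffles every lane holds the full result, so any single lane may write it out.
// The kernel name and the row-major layout (blockDim.x == WARP_SIZE) are assumptions made for
// the example.
static __global__ void ggml_cuda_example_row_stats(const float * x, float * sums, float * maxs) {
    const int row  = blockIdx.x;
    const int lane = threadIdx.x;

    const float v = x[row*WARP_SIZE + lane];

    const float row_sum = warp_reduce_sum(v);
    const float row_max = warp_reduce_max(v);

    if (lane == 0) {
        sums[row] = row_sum;
        maxs[row] = row_max;
    }
}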
//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//#pragma unroll
//   for (int mask = 16; mask > 0; mask >>= 1) {
//       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
//   }
//   return x;
//#else
//   GGML_UNUSED(x);
//   NO_DEVICE_CODE;
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//}

#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
    defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif

#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif

#ifndef __has_builtin
#define __has_builtin(x) 0
#endif

typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

// HIP stand-in for CUDA's per-byte saturating subtraction __vsubss4
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
    return reinterpret_cast<const int &>(c);
#else
    int8x4_t c;
    int16_t tmp;
#pragma unroll
    for (int i = 0; i < 4; i++) {
        tmp = va[i] - vb[i];
        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
        c[i] = tmp;
    }
    return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}

// HIP stand-in for CUDA's __vsub4 (implemented here via the saturating __vsubss4)
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
    return __vsubss4(a, b);
}

// HIP stand-in for CUDA's per-byte comparison __vcmpeq4: each result byte is 0xff if the corresponding bytes are equal, 0x00 otherwise
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
    unsigned int c;
    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
    }
    return c;
}

// HIP stand-in for CUDA's packed int8 dot product __dp4a, using hardware dot-product instructions where available
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
    c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
    int tmp1;
    int tmp2;
    asm("\n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        "
        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
        : "v"(a), "v"(b)
    );
#else
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
    return c;
}
#endif // defined(GGML_USE_HIPBLAS)
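
// Illustrative sketch (not part of the original header): on both backends __dp4a(a, b, c)
// treats a and b as four packed signed bytes and returns c + a0*b0 + a1*b1 + a2*b2 + a3*b3
// (a native instruction on CUDA devices with compute capability >= MIN_CC_DP4A, emulated
// above for HIP). The wrapper below only shows the guard that quantized dot-product kernels
// rely on; its name is hypothetical.
static __device__ __forceinline__ int ggml_cuda_example_dp4a(const int a, const int b, const int c) {
#if defined(GGML_USE_HIPBLAS) || __CUDA_ARCH__ >= MIN_CC_DP4A
    return __dp4a(a, b, c);
#else
    GGML_UNUSED(a);
    GGML_UNUSED(b);
    NO_DEVICE_CODE;
    return c; // not reached on the device (no_device_code traps); keeps the host pass happy
#endif // defined(GGML_USE_HIPBLAS) || __CUDA_ARCH__ >= MIN_CC_DP4A
}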

// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
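
// Illustrative sketch (not part of the original header): a minimal callback matching the
// dequantize_kernel_t signature. The real dequantizers for the quantized types live in the
// ggml-cuda sources; this hypothetical one treats vx as plain floats stored in pairs, writing
// the two values selected by the block index ib and the intra-block offset iqs into v.
static __device__ __forceinline__ void ggml_cuda_example_dequantize_f32(
    const void * vx, const int ib, const int iqs, dfloat2 & v) {
    const float * x = (const float *) vx;

    v.x = x[2*ib + iqs + 0];
    v.y = x[2*ib + iqs + 1];
}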

//////////////////////

struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
        int     cc;              // compute capability
        size_t  smpb;            // max. shared memory per block
        bool    vmm;             // virtual memory support
        size_t  vmm_granularity; // granularity of virtual memory
        size_t  total_vram;
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};

const ggml_cuda_device_info & ggml_cuda_info();

void ggml_cuda_set_device(int device);
int ggml_cuda_get_device();
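
// Illustrative sketch (not part of the original header): reading the cached device information.
// ggml_cuda_info() is declared above and is filled in once by the backend; the helper name and
// the printing itself are made up for the example.
static inline void ggml_cuda_example_print_devices() {
    const ggml_cuda_device_info & info = ggml_cuda_info();
    for (int id = 0; id < info.device_count; ++id) {
        const auto & dev = info.devices[id];
        printf("device %d: cc=%d, smpb=%zu, vmm=%d, total_vram=%zu\n",
               id, dev.cc, dev.smpb, dev.vmm ? 1 : 0, dev.total_vram);
    }
}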

struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void free(void * ptr, size_t size) = 0;
};

template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr != nullptr) {
            pool->free(ptr, actual_size);
        }
    }

    // size is in number of elements
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
        return ptr;
    }

    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};
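
// Illustrative sketch (not part of the original header): RAII usage of ggml_cuda_pool_alloc.
// The element count and the memset are arbitrary; the point is that the buffer is returned to
// the pool automatically when the allocator goes out of scope.
static inline void ggml_cuda_example_pool_usage(ggml_cuda_pool & pool, cudaStream_t stream) {
    ggml_cuda_pool_alloc<float> tmp(pool, 1024); // requests 1024*sizeof(float) bytes from the pool
    CUDA_CHECK(cudaMemsetAsync(tmp.get(), 0, 1024*sizeof(float), stream));
    CUDA_CHECK(cudaStreamSynchronize(stream));
}   // ~ggml_cuda_pool_alloc() hands the buffer back to the pool here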

// backend interface

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};

struct ggml_backend_cuda_context {
    int device;
    std::string name;
    cudaEvent_t copy_event = nullptr;

    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

    explicit ggml_backend_cuda_context(int device) :
        device(device),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ~ggml_backend_cuda_context() {
        if (copy_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(copy_event));
        }
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                if (streams[i][j] != nullptr) {
                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                }
            }
            if (cublas_handles[i] != nullptr) {
                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
            }
        }
    }

    cudaStream_t stream(int device, int stream) {
        if (streams[device][stream] == nullptr) {
            ggml_cuda_set_device(device);
            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
        }
        return streams[device][stream];
    }

    cudaStream_t stream() {
        return stream(device, 0);
    }

    cublasHandle_t cublas_handle(int device) {
        if (cublas_handles[device] == nullptr) {
            ggml_cuda_set_device(device);
            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
        }
        return cublas_handles[device];
    }

    cublasHandle_t cublas_handle() {
        return cublas_handle(device);
    }

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

    ggml_cuda_pool & pool(int device) {
        if (pools[device] == nullptr) {
            pools[device] = new_pool_for_device(device);
        }
        return *pools[device];
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
};
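
// Illustrative sketch (not part of the original header): how a backend context ties together the
// lazily created per-device streams, cuBLAS handles and memory pools. The function name is
// hypothetical and no actual GEMM is issued; it only shows where the resources come from.
static inline void ggml_cuda_example_context_usage(int device) {
    ggml_backend_cuda_context ctx(device);

    cudaStream_t     s      = ctx.stream();        // stream 0 on the context's device, created on first use
    cublasHandle_t   handle = ctx.cublas_handle(); // cuBLAS handle for the same device, also created lazily
    ggml_cuda_pool & pool   = ctx.pool();          // per-device memory pool

    CUBLAS_CHECK(cublasSetStream(handle, s));      // a real matrix multiplication would be launched here
    GGML_UNUSED(pool);
}   // ~ggml_backend_cuda_context() destroys the streams and handles it created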