// common.cuh

#pragma once

#include "ggml.h"
#include "ggml-cuda.h"

#include <memory>

#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#endif
#include "ggml-common.h"

#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <limits> // std::numeric_limits, used by the HIP __vsubss4 fallback below
#include <string>
#include <vector>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>

#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#endif // defined(GGML_USE_HIPBLAS)

#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)

#define WARP_SIZE 32
#define CUDART_HMAX  11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CUDART_HMASK 12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons

#define CC_PASCAL     600
#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA      700
#define CC_TURING     750
#define CC_AMPERE     800
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
// Define this to always fall back to the MMQ kernels and never use cuBLAS for matrix multiplication.
// On modern hardware, using cuBLAS is recommended as it utilizes the F16 tensor cores, which are very performant
// for large matrix multiplications. The drawback is that it requires some extra VRAM:
// - 7B quantized model:  +100-200 MB
// - 13B quantized model: +200-400 MB
//
//#define GGML_CUDA_FORCE_MMQ

// TODO: improve this to be correct for more hardware
// for example, it currently fails for the GeForce GTX 1660, which is Turing arch (> Volta) but does not have tensor cores
#if !defined(GGML_CUDA_FORCE_MMQ)
#define CUDA_USE_TENSOR_CORES
#endif

#define MMVQ_MAX_BATCH_SIZE  8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE  64 // max batch size to use MMQ kernels when tensor cores are available

#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#define GGML_CUDA_MAX_STREAMS 8

[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);

#define CUDA_CHECK_GEN(err, success, error_fn)                                   \
    do {                                                                         \
        auto err_ = (err);                                                       \
        if (err_ != (success)) {                                                 \
            ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
        }                                                                        \
    } while (0)

#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
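
// Illustrative usage note (editorial, not part of the upstream header): CUDA_CHECK wraps a
// CUDA runtime call and aborts with the file, line and failing expression on error, e.g.
//
//   CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, cudaMemcpyHostToDevice, stream));
//
// CUDA_CHECK_GEN is the generic form; CUBLAS_CHECK and CU_CHECK below instantiate it for
// cuBLAS statuses and CUDA driver API results respectively.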
#if CUDART_VERSION >= 12000
static const char * cublas_get_error_str(const cublasStatus_t err) {
    return cublasGetStatusString(err);
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
    switch (err) {
        case CUBLAS_STATUS_SUCCESS:          return "CUBLAS_STATUS_SUCCESS";
        case CUBLAS_STATUS_NOT_INITIALIZED:  return "CUBLAS_STATUS_NOT_INITIALIZED";
        case CUBLAS_STATUS_ALLOC_FAILED:     return "CUBLAS_STATUS_ALLOC_FAILED";
        case CUBLAS_STATUS_INVALID_VALUE:    return "CUBLAS_STATUS_INVALID_VALUE";
        case CUBLAS_STATUS_ARCH_MISMATCH:    return "CUBLAS_STATUS_ARCH_MISMATCH";
        case CUBLAS_STATUS_MAPPING_ERROR:    return "CUBLAS_STATUS_MAPPING_ERROR";
        case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
        case CUBLAS_STATUS_INTERNAL_ERROR:   return "CUBLAS_STATUS_INTERNAL_ERROR";
        case CUBLAS_STATUS_NOT_SUPPORTED:    return "CUBLAS_STATUS_NOT_SUPPORTED";
        default:                             return "unknown error";
    }
}
#endif // CUDART_VERSION >= 12000

#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)

#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
    const char * err_str;
    cuGetErrorString(err, &err_str);
    return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif

#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100

#ifdef GGML_CUDA_F16
typedef half  dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float  dfloat; // dequantize float
typedef float2 dfloat2;
#endif // GGML_CUDA_F16

#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300

#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
    defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif

#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif

#ifndef __has_builtin
    #define __has_builtin(x) 0
#endif

typedef int8_t  int8x4_t  __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
    const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
    return reinterpret_cast<const int &>(c);
#else
    int8x4_t c;
    int16_t tmp;
#pragma unroll
    for (int i = 0; i < 4; i++) {
        tmp = va[i] - vb[i];
        if (tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
        if (tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
        c[i] = tmp;
    }
    return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}

static __device__ __forceinline__ int __vsub4(const int a, const int b) {
    return __vsubss4(a, b);
}

static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
    const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
    const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
    unsigned int c;
    uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
    for (int i = 0; i < 4; ++i) {
        vc[i] = va[i] == vb[i] ? 0xff : 0x00;
    }
    return c;
}
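
// HIP replacement for CUDA's __dp4a intrinsic: the result is c plus the dot product of the
// four signed bytes packed into a and b, i.e. c + a0*b0 + a1*b1 + a2*b2 + a3*b3.
// On gfx906/gfx908/gfx90a/gfx1030 and RDNA3 this maps onto the hardware dot-product builtins
// (__builtin_amdgcn_sdot4 / __builtin_amdgcn_sudot4), gfx900/gfx1010 use inline assembly, and
// all other architectures fall back to plain scalar arithmetic.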
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
    c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
    c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
    int tmp1;
    int tmp2;
    asm("\n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
        v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
        v_add3_u32 %0, %1, %2, %0 \n \
        "
        : "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
        : "v"(a), "v"(b)
    );
#else
    const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
    const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
    c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
    return c;
}

#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
// __shfl_xor() for half2 was added in ROCm 5.6
static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
    typedef union half2_b32 {
        half2 val;
        int   b32;
    } half2_b32_t;
    half2_b32_t tmp;
    tmp.val = var;
    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
    return tmp.val;
}
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
#endif // defined(GGML_USE_HIPBLAS)

#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#define FP16_AVAILABLE
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
#define FP16_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
#define INT8_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING

static bool fast_fp16_available(const int cc) {
    return cc >= CC_PASCAL && cc != 610;
}

static bool fp16_mma_available(const int cc) {
    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
}

static bool int8_mma_available(const int cc) {
    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
}
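
// Note (editorial): the three helpers above are the host-side, runtime counterparts of the
// compile-time FP16_AVAILABLE / FP16_MMA_AVAILABLE / INT8_MMA_AVAILABLE macros. They take the
// compute capability reported for a device (offset by CC_OFFSET_AMD on AMD GPUs, as in the
// CC_RDNA* constants above) and tell the dispatch code which kernels that device can run.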
[[noreturn]]
static __device__ void no_device_code(
    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
           file_name, line, function_name, arch);
    GGML_UNUSED(arch_list);
#else
    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
           file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
    __trap();

    GGML_UNUSED(no_device_code); // suppress unused function warning
}

#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__

static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
    }
    return x;
}
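
// Illustrative sketch (editorial, not part of the upstream header): a common pattern is to let
// each of the 32 lanes of a warp accumulate a partial value and then combine them with the
// butterfly reduction above, after which every lane holds the warp-wide sum. Hypothetical
// example (assumes the kernel is launched with blockDim.x == WARP_SIZE):
//
//   __global__ void example_row_sum(const float * x, float * out, const int n) {
//       float acc = 0.0f;
//       for (int i = threadIdx.x; i < n; i += WARP_SIZE) {
//           acc += x[i];
//       }
//       acc = warp_reduce_sum(acc); // all 32 lanes now hold the same total
//       if (threadIdx.x == 0) {
//           *out = acc;
//       }
//   }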
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
    }
    return a;
}

static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#ifdef FP16_AVAILABLE
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
        reinterpret_cast<half&>(a.x) += __low2half(a_other);
        reinterpret_cast<half&>(a.y) += __high2half(a_other);
    }
    return a;
#else
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#else
    NO_DEVICE_CODE;
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
    return __float2half(fmaxf(__half2float(a), __half2float(b)));
#else
    return __hmax(a, b);
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX

#else
    NO_DEVICE_CODE;
    GGML_UNUSED(b);
    return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))

#if CUDART_VERSION >= CUDART_HMAX
    return __hmax2(a, b);
#else
    half2 ret;
    reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
    reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
    return ret;
#endif // CUDART_VERSION >= CUDART_HMAX

#else
    GGML_UNUSED(a);
    GGML_UNUSED(b);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}

static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
#else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#if CUDART_VERSION < CUDART_HMASK
static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
    const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
    const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
    return mask_low | mask_high;
}
#endif // CUDART_VERSION < CUDART_HMASK
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};

typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);

static __device__ __forceinline__ float get_alibi_slope(
    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}
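
// Note (editorial): in closed form, the ALiBi slope computed above is
//   m0^(h + 1)                         for heads h <  n_head_log2, and
//   m1^(2*(h - n_head_log2) + 1)       for heads h >= n_head_log2,
// with a slope of 1.0f when max_bias <= 0.0f (i.e. ALiBi disabled).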
template <ggml_type type>
struct ggml_cuda_type_traits;

template<>
struct ggml_cuda_type_traits<GGML_TYPE_F16> {
    static constexpr int qk = 1;
    static constexpr int qr = 1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> {
    static constexpr int qk = QK4_0;
    static constexpr int qr = QR4_0;
    static constexpr int qi = QI4_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_1> {
    static constexpr int qk = QK4_1;
    static constexpr int qr = QR4_1;
    static constexpr int qi = QI4_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_0> {
    static constexpr int qk = QK5_0;
    static constexpr int qr = QR5_0;
    static constexpr int qi = QI5_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_1> {
    static constexpr int qk = QK5_1;
    static constexpr int qr = QR5_1;
    static constexpr int qi = QI5_1;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
    static constexpr int qk = QK8_0;
    static constexpr int qr = QR8_0;
    static constexpr int qi = QI8_0;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_K;
    static constexpr int qi = QI2_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q3_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_K;
    static constexpr int qi = QI3_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q4_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_K;
    static constexpr int qi = QI4_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q5_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR5_K;
    static constexpr int qi = QI5_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q6_K> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR6_K;
    static constexpr int qi = QI6_K;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XXS;
    static constexpr int qi = QI2_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_XS;
    static constexpr int qi = QI2_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ2_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR2_S;
    static constexpr int qi = QI2_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_XXS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_XXS;
    static constexpr int qi = QI3_XXS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_S;
    static constexpr int qi = QI1_S;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ1_M> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR1_M;
    static constexpr int qi = QI1_M;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
    static constexpr int qk = QK4_NL;
    static constexpr int qr = QR4_NL;
    static constexpr int qi = QI4_NL;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR4_XS;
    static constexpr int qi = QI4_XS;
};

template<>
struct ggml_cuda_type_traits<GGML_TYPE_IQ3_S> {
    static constexpr int qk = QK_K;
    static constexpr int qr = QR3_S;
    static constexpr int qi = QI3_S;
};
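
// Illustrative sketch (editorial, not part of the upstream header): kernels consume these
// traits at compile time. For example, a hypothetical helper that derives the number of
// quantized blocks per row from the block size qk of a given type:
//
//   template <ggml_type type>
//   static __host__ __device__ constexpr int blocks_per_row(const int ncols) {
//       return ncols / ggml_cuda_type_traits<type>::qk;
//   }
//
//   // e.g. blocks_per_row<GGML_TYPE_Q4_0>(4096) == 4096/QK4_0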
static int get_mmq_x_max_host(const int cc) {
#ifdef CUDA_USE_TENSOR_CORES
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_MAX_BATCH_SIZE : 64;
#else
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128 : 64;
#endif // CUDA_USE_TENSOR_CORES
}

// Round rows to this value for --split-mode row:
static int get_mmq_y_host(const int cc, const int mmq_x) {
    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
}

//////////////////////

struct ggml_cuda_device_info {
    int device_count;

    struct cuda_device_info {
        int     cc;              // compute capability
        int     nsm;             // number of streaming multiprocessors
        size_t  smpb;            // max. shared memory per block
        bool    vmm;             // virtual memory support
        size_t  vmm_granularity; // granularity of virtual memory
        size_t  total_vram;
    };

    cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};

    std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};

const ggml_cuda_device_info & ggml_cuda_info();

void ggml_cuda_set_device(int device);
int  ggml_cuda_get_device();
struct ggml_cuda_pool {
    virtual ~ggml_cuda_pool() = default;

    virtual void * alloc(size_t size, size_t * actual_size) = 0;
    virtual void   free(void * ptr, size_t size) = 0;
};

template<typename T>
struct ggml_cuda_pool_alloc {
    ggml_cuda_pool * pool = nullptr;
    T * ptr = nullptr;
    size_t actual_size = 0;

    ggml_cuda_pool_alloc() = default;

    explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
    }

    ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
        alloc(size);
    }

    ~ggml_cuda_pool_alloc() {
        if (ptr != nullptr) {
            pool->free(ptr, actual_size);
        }
    }

    // size is in number of elements
    T * alloc(size_t size) {
        GGML_ASSERT(pool != nullptr);
        GGML_ASSERT(ptr == nullptr);
        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
        return ptr;
    }

    T * alloc(ggml_cuda_pool & pool, size_t size) {
        this->pool = &pool;
        return alloc(size);
    }

    T * get() {
        return ptr;
    }

    ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
    ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
    ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};
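
// Illustrative sketch (editorial, not part of the upstream header): ggml_cuda_pool_alloc is an
// RAII wrapper for temporary device buffers; the memory is handed back to the pool when the
// wrapper goes out of scope. Hypothetical usage inside an op, assuming a
// ggml_backend_cuda_context named ctx and a kernel named my_kernel:
//
//   ggml_cuda_pool_alloc<float> tmp(ctx.pool(), nelements); // allocates nelements floats
//   my_kernel<<<grid, block, 0, ctx.stream()>>>(tmp.get(), nelements);
//   // tmp.ptr is returned to the pool at the end of the scope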
// backend interface

struct ggml_tensor_extra_gpu {
    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};

#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
#define USE_CUDA_GRAPH
#endif

struct ggml_graph_node_properties {
    void * node_address;
    ggml_op node_op;
    int64_t ne[GGML_MAX_DIMS];
    size_t nb[GGML_MAX_DIMS];
    void * src_address[GGML_MAX_SRC];
};

struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
    ~ggml_cuda_graph() {
        if (instance != nullptr) {
            CUDA_CHECK(cudaGraphExecDestroy(instance));
        }
        if (graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(graph));
        }
    }
    cudaGraph_t graph = nullptr;
    cudaGraphExec_t instance = nullptr;
    size_t num_nodes = 0;
    std::vector<cudaGraphNode_t> nodes;
    std::vector<cudaKernelNodeParams> params;
    bool disable_due_to_gpu_arch = false;
    bool disable_due_to_too_many_updates = false;
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
    std::vector<char **> updated_kernel_arg;
#endif
};

struct ggml_backend_cuda_context {
    int device;
    std::string name;
    cudaEvent_t copy_event = nullptr;

    cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
    cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

    std::unique_ptr<ggml_cuda_graph> cuda_graph;

    explicit ggml_backend_cuda_context(int device) :
        device(device),
        name(GGML_CUDA_NAME + std::to_string(device)) {
    }

    ~ggml_backend_cuda_context() {
        if (copy_event != nullptr) {
            CUDA_CHECK(cudaEventDestroy(copy_event));
        }
        for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
            for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
                if (streams[i][j] != nullptr) {
                    CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
                }
            }
            if (cublas_handles[i] != nullptr) {
                CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
            }
        }
    }

    cudaStream_t stream(int device, int stream) {
        if (streams[device][stream] == nullptr) {
            ggml_cuda_set_device(device);
            CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
        }
        return streams[device][stream];
    }

    cudaStream_t stream() {
        return stream(device, 0);
    }

    cublasHandle_t cublas_handle(int device) {
        if (cublas_handles[device] == nullptr) {
            ggml_cuda_set_device(device);
            CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
            CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
        }
        return cublas_handles[device];
    }

    cublasHandle_t cublas_handle() {
        return cublas_handle(device);
    }

    // pool
    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];

    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);

    ggml_cuda_pool & pool(int device) {
        if (pools[device] == nullptr) {
            pools[device] = new_pool_for_device(device);
        }
        return *pools[device];
    }

    ggml_cuda_pool & pool() {
        return pool(device);
    }
};
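
// Illustrative note (editorial): the streams, cuBLAS handles and memory pools held by
// ggml_backend_cuda_context are created lazily on first use and released in the destructor,
// so a context only pays for the devices it actually touches. Hypothetical caller:
//
//   ggml_backend_cuda_context ctx(0);            // context bound to device 0
//   cudaStream_t     s   = ctx.stream();         // creates stream 0 for device 0 on demand
//   cublasHandle_t  blas = ctx.cublas_handle();  // creates the cuBLAS handle on demand
//   ggml_cuda_pool & p   = ctx.pool();           // creates the device memory pool on demand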