ggml-cuda.cu 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <stdint.h>
  4. #include <stdio.h>
  5. #include <atomic>
  6. #include <cuda_runtime.h>
  7. #include <cublas_v2.h>
  8. #include <cuda_fp16.h>
  9. #include "ggml-cuda.h"
  10. #include "ggml.h"
  11. static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
  12. #define CUDA_CHECK(err) \
  13. do { \
  14. cudaError_t err_ = (err); \
  15. if (err_ != cudaSuccess) { \
  16. fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
  17. cudaGetErrorString(err_)); \
  18. exit(1); \
  19. } \
  20. } while (0)
  21. #define CUBLAS_CHECK(err) \
  22. do { \
  23. cublasStatus_t err_ = (err); \
  24. if (err_ != CUBLAS_STATUS_SUCCESS) { \
  25. fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
  26. exit(1); \
  27. } \
  28. } while (0)
  29. typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
  30. typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
  31. typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
  32. // QK = number of values after dequantization
  33. // QR = QK / number of values before dequantization
  34. #define QK4_0 32
  35. #define QR4_0 2
  36. typedef struct {
  37. float d; // delta
  38. uint8_t qs[QK4_0 / 2]; // nibbles / quants
  39. } block_q4_0;
  40. static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
  41. #define QK4_1 32
  42. #define QR4_1 2
  43. typedef struct {
  44. float d; // delta
  45. float m; // min
  46. uint8_t qs[QK4_1 / 2]; // nibbles / quants
  47. } block_q4_1;
  48. static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
  49. #define QK5_0 32
  50. #define QR5_0 2
  51. typedef struct {
  52. half d; // delta
  53. uint8_t qh[4]; // 5-th bit of quants
  54. uint8_t qs[QK5_0 / 2]; // nibbles / quants
  55. } block_q5_0;
  56. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  57. #define QK5_1 32
  58. #define QR5_1 2
  59. typedef struct {
  60. half d; // delta
  61. half m; // min
  62. uint8_t qh[4]; // 5-th bit of quants
  63. uint8_t qs[QK5_1 / 2]; // nibbles / quants
  64. } block_q5_1;
  65. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  66. #define QK8_0 32
  67. #define QR8_0 1
  68. typedef struct {
  69. float d; // delta
  70. int8_t qs[QK8_0]; // quants
  71. } block_q8_0;
  72. static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
  73. #define CUDA_DMMV_BLOCK_SIZE 32
  74. static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  75. const block_q4_0 * x = (const block_q4_0 *) vx;
  76. const float d = x[ib].d;
  77. const uint8_t vui = x[ib].qs[iqs];
  78. const int8_t vi0 = vui & 0xF;
  79. const int8_t vi1 = vui >> 4;
  80. v0 = (vi0 - 8)*d;
  81. v1 = (vi1 - 8)*d;
  82. }
  83. static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  84. const block_q4_1 * x = (const block_q4_1 *) vx;
  85. const float d = x[ib].d;
  86. const float m = x[ib].m;
  87. const uint8_t vui = x[ib].qs[iqs];
  88. const int8_t vi0 = vui & 0xF;
  89. const int8_t vi1 = vui >> 4;
  90. v0 = vi0*d + m;
  91. v1 = vi1*d + m;
  92. }
  93. static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  94. const block_q5_0 * x = (const block_q5_0 *) vx;
  95. const float d = x[ib].d;
  96. uint32_t qh;
  97. memcpy(&qh, x[ib].qh, sizeof(qh));
  98. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  99. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  100. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
  101. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
  102. v0 = x0*d;
  103. v1 = x1*d;
  104. }
  105. static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  106. const block_q5_1 * x = (const block_q5_1 *) vx;
  107. const float d = x[ib].d;
  108. const float m = x[ib].m;
  109. uint32_t qh;
  110. memcpy(&qh, x[ib].qh, sizeof(qh));
  111. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  112. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  113. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
  114. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
  115. v0 = x0*d + m;
  116. v1 = x1*d + m;
  117. }
  118. static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  119. const block_q8_0 * x = (const block_q8_0 *) vx;
  120. const float d = x[ib].d;
  121. const int8_t vi0 = x[ib].qs[iqs + 0];
  122. const int8_t vi1 = x[ib].qs[iqs + 1];
  123. v0 = vi0*d;
  124. v1 = vi1*d;
  125. }
  126. static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  127. const half * x = (const half *) vx;
  128. v0 = __half2float(x[ib + 0]);
  129. v1 = __half2float(x[ib + 1]);
  130. }
  131. static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
  132. static const int qk = QK4_0;
  133. const block_q4_0 * x = (const block_q4_0 *) vx;
  134. const int i = blockIdx.x;
  135. const float d = x[i].d;
  136. for (int j = 0; j < qk/2; ++j) {
  137. const int x0 = (x[i].qs[j] & 0xf) - 8;
  138. const int x1 = (x[i].qs[j] >> 4) - 8;
  139. y[i*qk + j + 0 ] = x0*d;
  140. y[i*qk + j + qk/2] = x1*d;
  141. }
  142. }
  143. static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
  144. static const int qk = QK4_1;
  145. const block_q4_1 * x = (const block_q4_1 *) vx;
  146. const int i = blockIdx.x;
  147. const float d = x[i].d;
  148. const float m = x[i].m;
  149. for (int j = 0; j < qk/2; ++j) {
  150. const int x0 = (x[i].qs[j] & 0xf);
  151. const int x1 = (x[i].qs[j] >> 4);
  152. y[i*qk + j + 0 ] = x0*d + m;
  153. y[i*qk + j + qk/2] = x1*d + m;
  154. }
  155. }
  156. static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
  157. static const int qk = QK5_0;
  158. const block_q5_0 * x = (const block_q5_0 *) vx;
  159. const int i = blockIdx.x;
  160. const float d = x[i].d;
  161. uint32_t qh;
  162. memcpy(&qh, x[i].qh, sizeof(qh));
  163. for (int j = 0; j < qk/2; ++j) {
  164. const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
  165. const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
  166. const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
  167. const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
  168. y[i*qk + j + 0 ] = x0*d;
  169. y[i*qk + j + qk/2] = x1*d;
  170. }
  171. }
  172. static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
  173. static const int qk = QK5_1;
  174. const block_q5_1 * x = (const block_q5_1 *) vx;
  175. const int i = blockIdx.x;
  176. const float d = x[i].d;
  177. const float m = x[i].m;
  178. uint32_t qh;
  179. memcpy(&qh, x[i].qh, sizeof(qh));
  180. for (int j = 0; j < qk/2; ++j) {
  181. const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
  182. const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
  183. const int x0 = (x[i].qs[j] & 0xf) | xh_0;
  184. const int x1 = (x[i].qs[j] >> 4) | xh_1;
  185. y[i*qk + j + 0 ] = x0*d + m;
  186. y[i*qk + j + qk/2] = x1*d + m;
  187. }
  188. }
  189. static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
  190. static const int qk = QK8_0;
  191. const block_q8_0 * x = (const block_q8_0 *) vx;
  192. const int i = blockIdx.x;
  193. const float d = x[i].d;
  194. for (int j = 0; j < qk; ++j) {
  195. y[i*qk + j] = x[i].qs[j]*d;
  196. }
  197. }
  198. template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
  199. static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
  200. const int row = blockIdx.x;
  201. const int tid = threadIdx.x;
  202. const int y_offset = qr == 1 ? 1 : qk/2;
  203. __shared__ float tmp[block_size]; // separate sum for each thread
  204. tmp[tid] = 0;
  205. for (int i = 0; i < ncols/block_size; i += 2) {
  206. const int col = i*block_size + 2*tid;
  207. const int ib = (row*ncols + col)/qk; // block index
  208. const int iqs = (col%qk)/qr; // quant index
  209. const int iybs = col - col%qk; // y block start index
  210. // dequantize
  211. float v0, v1;
  212. dequantize_kernel(vx, ib, iqs, v0, v1);
  213. // matrix multiplication
  214. tmp[tid] += v0 * y[iybs + iqs + 0];
  215. tmp[tid] += v1 * y[iybs + iqs + y_offset];
  216. }
  217. // sum up partial sums and write back result
  218. __syncthreads();
  219. for (int s=block_size/2; s>0; s>>=1) {
  220. if (tid < s) {
  221. tmp[tid] += tmp[tid + s];
  222. }
  223. __syncthreads();
  224. }
  225. if (tid == 0) {
  226. dst[row] = tmp[0];
  227. }
  228. }
  229. static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  230. const int nb = k / QK4_0;
  231. dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
  232. }
  233. static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  234. const int nb = k / QK4_1;
  235. dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
  236. }
  237. static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  238. const int nb = k / QK5_0;
  239. dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
  240. }
  241. static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  242. const int nb = k / QK5_1;
  243. dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
  244. }
  245. static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  246. const int nb = k / QK8_0;
  247. dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
  248. }
  249. static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  250. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  251. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
  252. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  253. }
  254. static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  255. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  256. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
  257. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  258. }
  259. static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  260. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  261. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
  262. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  263. }
  264. static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  265. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  266. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
  267. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  268. }
  269. static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  270. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  271. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
  272. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  273. }
  274. // TODO: optimize
  275. static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
  276. const half * x = (const half *) vx;
  277. const int i = blockIdx.x;
  278. y[i] = __half2float(x[i]);
  279. }
  280. static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
  281. convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
  282. }
  283. static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  284. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  285. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
  286. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  287. }
  288. static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  289. switch (type) {
  290. case GGML_TYPE_Q4_0:
  291. return dequantize_row_q4_0_cuda;
  292. case GGML_TYPE_Q4_1:
  293. return dequantize_row_q4_1_cuda;
  294. case GGML_TYPE_Q5_0:
  295. return dequantize_row_q5_0_cuda;
  296. case GGML_TYPE_Q5_1:
  297. return dequantize_row_q5_1_cuda;
  298. case GGML_TYPE_Q8_0:
  299. return dequantize_row_q8_0_cuda;
  300. case GGML_TYPE_F16:
  301. return convert_fp16_to_fp32_cuda;
  302. default:
  303. return nullptr;
  304. }
  305. }
  306. static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
  307. switch (type) {
  308. case GGML_TYPE_Q4_0:
  309. return dequantize_mul_mat_vec_q4_0_cuda;
  310. case GGML_TYPE_Q4_1:
  311. return dequantize_mul_mat_vec_q4_1_cuda;
  312. case GGML_TYPE_Q5_0:
  313. return dequantize_mul_mat_vec_q5_0_cuda;
  314. case GGML_TYPE_Q5_1:
  315. return dequantize_mul_mat_vec_q5_1_cuda;
  316. case GGML_TYPE_Q8_0:
  317. return dequantize_mul_mat_vec_q8_0_cuda;
  318. case GGML_TYPE_F16:
  319. return dequantize_mul_mat_vec_q8_0_cuda;
  320. default:
  321. return nullptr;
  322. }
  323. }
  324. // buffer pool for cuda
  325. #define MAX_CUDA_BUFFERS 256
  326. struct scoped_spin_lock {
  327. std::atomic_flag& lock;
  328. scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
  329. while (lock.test_and_set(std::memory_order_acquire)) {
  330. ; // spin
  331. }
  332. }
  333. ~scoped_spin_lock() {
  334. lock.clear(std::memory_order_release);
  335. }
  336. scoped_spin_lock(const scoped_spin_lock&) = delete;
  337. scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  338. };
  339. struct cuda_buffer {
  340. void * ptr = nullptr;
  341. size_t size = 0;
  342. };
  343. static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
  344. static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  345. static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
  346. scoped_spin_lock lock(g_cuda_pool_lock);
  347. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  348. cuda_buffer& b = g_cuda_buffer_pool[i];
  349. if (b.size >= size && b.ptr != nullptr) {
  350. void * ptr = b.ptr;
  351. *actual_size = b.size;
  352. b.ptr = nullptr;
  353. b.size = 0;
  354. return ptr;
  355. }
  356. }
  357. void * ptr;
  358. CUDA_CHECK(cudaMalloc((void **) &ptr, size));
  359. *actual_size = size;
  360. return ptr;
  361. }
  362. static void ggml_cuda_pool_free(void * ptr, size_t size) {
  363. scoped_spin_lock lock(g_cuda_pool_lock);
  364. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  365. cuda_buffer& b = g_cuda_buffer_pool[i];
  366. if (b.ptr == nullptr) {
  367. b.ptr = ptr;
  368. b.size = size;
  369. return;
  370. }
  371. }
  372. fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
  373. CUDA_CHECK(cudaFree(ptr));
  374. }
  375. #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
  376. #define GGML_CUDA_MAX_EVENTS 64
  377. static cublasHandle_t g_cublasH = nullptr;
  378. static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
  379. static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
  380. static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
  381. void ggml_init_cublas() {
  382. if (g_cublasH == nullptr) {
  383. // create streams
  384. for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
  385. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
  386. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
  387. }
  388. // create events
  389. for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
  390. CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
  391. }
  392. // create cublas handle
  393. CUBLAS_CHECK(cublasCreate(&g_cublasH));
  394. CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
  395. // configure logging to stdout
  396. // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  397. }
  398. }
  399. void * ggml_cuda_host_malloc(size_t size) {
  400. if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
  401. return nullptr;
  402. }
  403. void * ptr = nullptr;
  404. cudaError_t err = cudaMallocHost((void **) &ptr, size);
  405. if (err != cudaSuccess) {
  406. fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
  407. size/1024.0/1024.0, cudaGetErrorString(err));
  408. return nullptr;
  409. }
  410. return ptr;
  411. }
  412. void ggml_cuda_host_free(void * ptr) {
  413. CUDA_CHECK(cudaFreeHost(ptr));
  414. }
  415. static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
  416. const uint64_t ne0 = src->ne[0];
  417. const uint64_t ne1 = src->ne[1];
  418. const uint64_t nb0 = src->nb[0];
  419. const uint64_t nb1 = src->nb[1];
  420. const uint64_t nb2 = src->nb[2];
  421. const uint64_t nb3 = src->nb[3];
  422. const enum ggml_type type = src->type;
  423. const size_t ts = ggml_type_size(type);
  424. const size_t bs = ggml_blck_size(type);
  425. const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
  426. if (nb0 == ts && nb1 == ts*ne0/bs) {
  427. return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
  428. } else if (nb0 == ts) {
  429. return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
  430. } else {
  431. for (uint64_t i1 = 0; i1 < ne1; i1++) {
  432. const void * rx = (const void *) ((const char *) x + i1*nb1);
  433. void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
  434. // pretend the row is a matrix with cols=1
  435. cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
  436. if (r != cudaSuccess) return r;
  437. }
  438. return cudaSuccess;
  439. }
  440. }
  441. static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  442. const int64_t ne00 = src0->ne[0];
  443. const int64_t ne01 = src0->ne[1];
  444. const int64_t ne02 = src0->ne[2];
  445. const int64_t ne03 = src0->ne[3];
  446. const int64_t ne10 = src1->ne[0];
  447. const int64_t ne11 = src1->ne[1];
  448. const int nb2 = dst->nb[2];
  449. const int nb3 = dst->nb[3];
  450. const float alpha = 1.0f;
  451. const float beta = 0.0f;
  452. const int x_ne = ne01 * ne00;
  453. const int y_ne = ne11 * ne10;
  454. const int d_ne = ne11 * ne01;
  455. const int n_mm = ne03 * ne02;
  456. size_t x_size, y_size, d_size;
  457. float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  458. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  459. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  460. for (int64_t i03 = 0; i03 < ne03; i03++) {
  461. for (int64_t i02 = 0; i02 < ne02; i02++) {
  462. int i = i03*ne02 + i02;
  463. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  464. float * c_X = d_X + i * x_ne;
  465. float * c_Y = d_Y + i * y_ne;
  466. float * c_D = d_D + i * d_ne;
  467. // copy data to device
  468. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  469. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  470. // compute
  471. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  472. CUBLAS_CHECK(
  473. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  474. ne01, ne11, ne10,
  475. &alpha, c_X, ne00,
  476. c_Y, ne10,
  477. &beta, c_D, ne01));
  478. // copy dst to host
  479. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  480. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  481. }
  482. }
  483. CUDA_CHECK(cudaDeviceSynchronize());
  484. ggml_cuda_pool_free(d_X, x_size);
  485. ggml_cuda_pool_free(d_Y, y_size);
  486. ggml_cuda_pool_free(d_D, d_size);
  487. }
  488. static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
  489. const int64_t ne00 = src0->ne[0];
  490. const int64_t ne01 = src0->ne[1];
  491. const int64_t ne02 = src0->ne[2];
  492. const int64_t ne03 = src0->ne[3];
  493. const int64_t ne10 = src1->ne[0];
  494. const int64_t ne11 = src1->ne[1];
  495. const int nb10 = src1->nb[0];
  496. const int nb11 = src1->nb[1];
  497. const int nb12 = src1->nb[2];
  498. const int nb13 = src1->nb[3];
  499. const int nb2 = dst->nb[2];
  500. const int nb3 = dst->nb[3];
  501. const float alpha = 1.0f;
  502. const float beta = 0.0f;
  503. const int x_ne = ne01 * ne00;
  504. const int y_ne = ne11 * ne10;
  505. const int d_ne = ne11 * ne01;
  506. const int n_mm = ne03 * ne02;
  507. size_t x_size, y_size, d_size;
  508. half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
  509. half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
  510. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  511. bool src1_cont_rows = nb10 == sizeof(float);
  512. bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
  513. for (int64_t i03 = 0; i03 < ne03; i03++) {
  514. for (int64_t i02 = 0; i02 < ne02; i02++) {
  515. int i = i03*ne02 + i02;
  516. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  517. half * c_X = d_X + i * x_ne;
  518. half * c_Y = d_Y + i * y_ne;
  519. float * c_D = d_D + i * d_ne;
  520. // copy src0 to device
  521. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  522. // convert src1 to fp16
  523. // TODO: use multiple threads
  524. ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
  525. char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
  526. if (src1_cont_rows) {
  527. if (src1_cont_cols) {
  528. ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
  529. }
  530. else {
  531. for (int64_t i01 = 0; i01 < ne11; i01++) {
  532. ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
  533. }
  534. }
  535. }
  536. else {
  537. for (int64_t i01 = 0; i01 < ne11; i01++) {
  538. for (int64_t i00 = 0; i00 < ne10; i00++) {
  539. // very slow due to no inlining
  540. tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
  541. }
  542. }
  543. }
  544. // copy src1 to device
  545. CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
  546. // compute
  547. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  548. CUBLAS_CHECK(
  549. cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  550. ne01, ne11, ne10,
  551. &alpha, c_X, CUDA_R_16F, ne00,
  552. c_Y, CUDA_R_16F, ne10,
  553. &beta, c_D, CUDA_R_32F, ne01,
  554. CUBLAS_COMPUTE_32F_FAST_16F,
  555. CUBLAS_GEMM_DEFAULT));
  556. // copy dst to host
  557. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  558. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  559. }
  560. }
  561. CUDA_CHECK(cudaDeviceSynchronize());
  562. ggml_cuda_pool_free(d_X, x_size);
  563. ggml_cuda_pool_free(d_Y, y_size);
  564. ggml_cuda_pool_free(d_D, d_size);
  565. }
  566. static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  567. const int64_t ne00 = src0->ne[0];
  568. const int64_t ne01 = src0->ne[1];
  569. const int64_t ne02 = src0->ne[2];
  570. const int64_t ne03 = src0->ne[3];
  571. const int64_t ne10 = src1->ne[0];
  572. const int64_t ne11 = src1->ne[1];
  573. const int nb2 = dst->nb[2];
  574. const int nb3 = dst->nb[3];
  575. const ggml_type type = src0->type;
  576. const bool mul_mat_vec = ne11 == 1;
  577. const float alpha = 1.0f;
  578. const float beta = 0.0f;
  579. const int x_ne = ne01 * ne00;
  580. const int y_ne = ne11 * ne10;
  581. const int d_ne = ne11 * ne01;
  582. const int n_mm = ne03 * ne02;
  583. const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
  584. size_t x_size, y_size, d_size, q_size;
  585. float * d_X = nullptr;
  586. if (!mul_mat_vec) {
  587. d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  588. }
  589. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  590. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  591. char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
  592. const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
  593. dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
  594. GGML_ASSERT(to_fp32_cuda != nullptr);
  595. for (int64_t i03 = 0; i03 < ne03; i03++) {
  596. for (int64_t i02 = 0; i02 < ne02; i02++) {
  597. int i = i03*ne02 + i02;
  598. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  599. cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
  600. cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
  601. float * c_Y = d_Y + i * y_ne;
  602. float * c_D = d_D + i * d_ne;
  603. char * c_Q = d_Q + i * q_sz;
  604. // copy src0 to device if necessary
  605. if (src0->backend == GGML_BACKEND_CPU) {
  606. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
  607. } else if (src0->backend == GGML_BACKEND_CUDA) {
  608. c_Q = ((char *) src0->data) + i * q_sz;
  609. } else {
  610. GGML_ASSERT(false);
  611. }
  612. if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
  613. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  614. // copy src1 to device
  615. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  616. // wait for data
  617. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  618. // compute
  619. dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
  620. CUDA_CHECK(cudaGetLastError());
  621. } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
  622. float * c_X = d_X + i * x_ne;
  623. // convert src0 to fp32 on device
  624. to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
  625. CUDA_CHECK(cudaGetLastError());
  626. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  627. // copy src1 to device
  628. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  629. // wait for conversion
  630. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  631. // compute
  632. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  633. CUBLAS_CHECK(
  634. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  635. ne01, ne11, ne10,
  636. &alpha, c_X, ne00,
  637. c_Y, ne10,
  638. &beta, c_D, ne01));
  639. }
  640. // copy dst to host
  641. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  642. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  643. }
  644. }
  645. CUDA_CHECK(cudaDeviceSynchronize());
  646. if (!mul_mat_vec) {
  647. ggml_cuda_pool_free(d_X, x_size);
  648. }
  649. ggml_cuda_pool_free(d_Y, y_size);
  650. ggml_cuda_pool_free(d_D, d_size);
  651. ggml_cuda_pool_free(d_Q, q_size);
  652. }
  653. bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  654. const int64_t ne10 = src1->ne[0];
  655. const int64_t ne0 = dst->ne[0];
  656. const int64_t ne1 = dst->ne[1];
  657. // TODO: find the optimal values for these
  658. if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
  659. src1->type == GGML_TYPE_F32 &&
  660. dst->type == GGML_TYPE_F32 &&
  661. ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
  662. return true;
  663. }
  664. return false;
  665. }
  666. bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
  667. size_t src0_sz = ggml_nbytes(src0);
  668. size_t src1_sz = ggml_nbytes(src1);
  669. // mul_mat_q: src0 is converted to fp32 on device
  670. size_t mul_mat_q_transfer = src0_sz + src1_sz;
  671. // mul_mat_f16: src1 is converted to fp16 on cpu
  672. size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
  673. // choose the smaller one to transfer to the device
  674. // TODO: this is not always the best choice due to the overhead of converting to fp16
  675. return mul_mat_f16_transfer < mul_mat_q_transfer;
  676. }
  677. void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
  678. GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
  679. if (src0->type == GGML_TYPE_F32) {
  680. ggml_cuda_mul_mat_f32(src0, src1, dst);
  681. }
  682. else if (src0->type == GGML_TYPE_F16) {
  683. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  684. ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
  685. }
  686. else {
  687. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  688. }
  689. }
  690. else if (ggml_is_quantized(src0->type)) {
  691. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  692. }
  693. else {
  694. GGML_ASSERT(false);
  695. }
  696. }
  697. size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  698. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  699. return ggml_nelements(src1) * sizeof(ggml_fp16_t);
  700. }
  701. else {
  702. return 0;
  703. }
  704. }
  705. void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
  706. const int64_t ne0 = tensor->ne[0];
  707. const int64_t ne1 = tensor->ne[1];
  708. const int64_t ne2 = tensor->ne[2];
  709. const int64_t ne3 = tensor->ne[3];
  710. const ggml_type type = tensor->type;
  711. const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
  712. size_t q_size;
  713. char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
  714. cudaStream_t cudaStream2 = g_cudaStreams2[0];
  715. // copy tensor to device
  716. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
  717. CUDA_CHECK(cudaDeviceSynchronize());
  718. tensor->data = d_Q;
  719. tensor->backend = GGML_BACKEND_CUDA;
  720. }