ggml-cuda.cu 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <stdint.h>
  4. #include <stdio.h>
  5. #include <atomic>
  6. #include <cuda_runtime.h>
  7. #include <cublas_v2.h>
  8. #include <cuda_fp16.h>
  9. #include "ggml-cuda.h"
  10. #include "ggml.h"
  11. static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
  12. #define CUDA_CHECK(err) \
  13. do { \
  14. cudaError_t err_ = (err); \
  15. if (err_ != cudaSuccess) { \
  16. fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
  17. cudaGetErrorString(err_)); \
  18. exit(1); \
  19. } \
  20. } while (0)
  21. #define CUBLAS_CHECK(err) \
  22. do { \
  23. cublasStatus_t err_ = (err); \
  24. if (err_ != CUBLAS_STATUS_SUCCESS) { \
  25. fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
  26. exit(1); \
  27. } \
  28. } while (0)
  29. typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
  30. #define QK4_0 32
  31. typedef struct {
  32. float d; // delta
  33. uint8_t qs[QK4_0 / 2]; // nibbles / quants
  34. } block_q4_0;
  35. static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
  36. #define QK4_1 32
  37. typedef struct {
  38. float d; // delta
  39. float m; // min
  40. uint8_t qs[QK4_1 / 2]; // nibbles / quants
  41. } block_q4_1;
  42. static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
  43. #define QK5_0 32
  44. typedef struct {
  45. half d; // delta
  46. uint8_t qh[4]; // 5-th bit of quants
  47. uint8_t qs[QK5_0 / 2]; // nibbles / quants
  48. } block_q5_0;
  49. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  50. #define QK5_1 32
  51. typedef struct {
  52. half d; // delta
  53. half m; // min
  54. uint8_t qh[4]; // 5-th bit of quants
  55. uint8_t qs[QK5_1 / 2]; // nibbles / quants
  56. } block_q5_1;
  57. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  58. #define QK8_0 32
  59. typedef struct {
  60. float d; // delta
  61. int8_t qs[QK8_0]; // quants
  62. } block_q8_0;
  63. static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
  64. static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
  65. static const int qk = QK4_0;
  66. const block_q4_0 * x = (const block_q4_0 *) vx;
  67. const int i = blockIdx.x;
  68. const float d = x[i].d;
  69. for (int j = 0; j < qk/2; ++j) {
  70. const int x0 = (x[i].qs[j] & 0xf) - 8;
  71. const int x1 = (x[i].qs[j] >> 4) - 8;
  72. y[i*qk + j + 0 ] = x0*d;
  73. y[i*qk + j + qk/2] = x1*d;
  74. }
  75. }
  76. static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
  77. static const int qk = QK4_1;
  78. const block_q4_1 * x = (const block_q4_1 *) vx;
  79. const int i = blockIdx.x;
  80. const float d = x[i].d;
  81. const float m = x[i].m;
  82. for (int j = 0; j < qk/2; ++j) {
  83. const int x0 = (x[i].qs[j] & 0xf);
  84. const int x1 = (x[i].qs[j] >> 4);
  85. y[i*qk + j + 0 ] = x0*d + m;
  86. y[i*qk + j + qk/2] = x1*d + m;
  87. }
  88. }
  89. static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
  90. static const int qk = QK5_0;
  91. const block_q5_0 * x = (const block_q5_0 *) vx;
  92. const int i = blockIdx.x;
  93. const float d = x[i].d;
  94. uint32_t qh;
  95. memcpy(&qh, x[i].qh, sizeof(qh));
  96. for (int j = 0; j < qk/2; ++j) {
  97. const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
  98. const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
  99. const int32_t x0 = ((x[i].qs[j] & 0xf) | xh_0) - 16;
  100. const int32_t x1 = ((x[i].qs[j] >> 4) | xh_1) - 16;
  101. y[i*qk + j + 0 ] = x0*d;
  102. y[i*qk + j + qk/2] = x1*d;
  103. }
  104. }
  105. static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
  106. static const int qk = QK5_1;
  107. const block_q5_1 * x = (const block_q5_1 *) vx;
  108. const int i = blockIdx.x;
  109. const float d = x[i].d;
  110. const float m = x[i].m;
  111. uint32_t qh;
  112. memcpy(&qh, x[i].qh, sizeof(qh));
  113. for (int j = 0; j < qk/2; ++j) {
  114. const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
  115. const uint8_t xh_1 = ((qh >> (j + 12)) ) & 0x10;
  116. const int x0 = (x[i].qs[j] & 0xf) | xh_0;
  117. const int x1 = (x[i].qs[j] >> 4) | xh_1;
  118. y[i*qk + j + 0 ] = x0*d + m;
  119. y[i*qk + j + qk/2] = x1*d + m;
  120. }
  121. }
  122. static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
  123. static const int qk = QK8_0;
  124. const block_q8_0 * x = (const block_q8_0 *) vx;
  125. const int i = blockIdx.x;
  126. const float d = x[i].d;
  127. for (int j = 0; j < qk; ++j) {
  128. y[i*qk + j] = x[i].qs[j]*d;
  129. }
  130. }
  131. static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  132. const int nb = k / QK4_0;
  133. dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
  134. }
  135. static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  136. const int nb = k / QK4_1;
  137. dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
  138. }
  139. static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  140. const int nb = k / QK5_0;
  141. dequantize_block_q5_0<<<nb, 1, 0, stream>>>(vx, y);
  142. }
  143. static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  144. const int nb = k / QK5_1;
  145. dequantize_block_q5_1<<<nb, 1, 0, stream>>>(vx, y);
  146. }
  147. static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
  148. const int nb = k / QK8_0;
  149. dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
  150. }
  151. // TODO: optimize
  152. static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
  153. const half * x = (const half *) vx;
  154. const int i = blockIdx.x;
  155. y[i] = __half2float(x[i]);
  156. }
  157. static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
  158. convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
  159. }
  160. static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  161. switch (type) {
  162. case GGML_TYPE_Q4_0:
  163. return dequantize_row_q4_0_cuda;
  164. case GGML_TYPE_Q4_1:
  165. return dequantize_row_q4_1_cuda;
  166. case GGML_TYPE_Q5_0:
  167. return dequantize_row_q5_0_cuda;
  168. case GGML_TYPE_Q5_1:
  169. return dequantize_row_q5_1_cuda;
  170. case GGML_TYPE_Q8_0:
  171. return dequantize_row_q8_0_cuda;
  172. case GGML_TYPE_F16:
  173. return convert_fp16_to_fp32_cuda;
  174. default:
  175. return nullptr;
  176. }
  177. }
  178. // buffer pool for cuda
  179. #define MAX_CUDA_BUFFERS 16
  180. struct scoped_spin_lock {
  181. std::atomic_flag& lock;
  182. scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
  183. while (lock.test_and_set(std::memory_order_acquire)) {
  184. ; // spin
  185. }
  186. }
  187. ~scoped_spin_lock() {
  188. lock.clear(std::memory_order_release);
  189. }
  190. scoped_spin_lock(const scoped_spin_lock&) = delete;
  191. scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  192. };
  193. struct cuda_buffer {
  194. void * ptr = nullptr;
  195. size_t size = 0;
  196. };
  197. static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
  198. static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  199. static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
  200. scoped_spin_lock lock(g_cuda_pool_lock);
  201. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  202. cuda_buffer& b = g_cuda_buffer_pool[i];
  203. if (b.size >= size && b.ptr != nullptr) {
  204. void * ptr = b.ptr;
  205. *actual_size = b.size;
  206. b.ptr = nullptr;
  207. b.size = 0;
  208. return ptr;
  209. }
  210. }
  211. void * ptr;
  212. CUDA_CHECK(cudaMalloc((void **) &ptr, size));
  213. *actual_size = size;
  214. return ptr;
  215. }
  216. static void ggml_cuda_pool_free(void * ptr, size_t size) {
  217. scoped_spin_lock lock(g_cuda_pool_lock);
  218. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  219. cuda_buffer& b = g_cuda_buffer_pool[i];
  220. if (b.ptr == nullptr) {
  221. b.ptr = ptr;
  222. b.size = size;
  223. return;
  224. }
  225. }
  226. fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
  227. CUDA_CHECK(cudaFree(ptr));
  228. }
  229. #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
  230. #define GGML_CUDA_MAX_EVENTS 64
  231. static cublasHandle_t g_cublasH = nullptr;
  232. static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
  233. static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
  234. static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
  235. void ggml_init_cublas() {
  236. if (g_cublasH == nullptr) {
  237. // create streams
  238. for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
  239. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
  240. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
  241. }
  242. // create events
  243. for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
  244. CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
  245. }
  246. // create cublas handle
  247. CUBLAS_CHECK(cublasCreate(&g_cublasH));
  248. CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
  249. // configure logging to stdout
  250. // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  251. }
  252. }
  253. void * ggml_cuda_host_malloc(size_t size) {
  254. if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
  255. return nullptr;
  256. }
  257. void * ptr = nullptr;
  258. cudaError_t err = cudaMallocHost((void **) &ptr, size);
  259. if (err != cudaSuccess) {
  260. fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
  261. size/1024.0/1024.0, cudaGetErrorString(err));
  262. return nullptr;
  263. }
  264. return ptr;
  265. }
  266. void ggml_cuda_host_free(void * ptr) {
  267. CUDA_CHECK(cudaFreeHost(ptr));
  268. }
  269. static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
  270. const uint64_t ne0 = src->ne[0];
  271. const uint64_t ne1 = src->ne[1];
  272. const uint64_t nb0 = src->nb[0];
  273. const uint64_t nb1 = src->nb[1];
  274. const uint64_t nb2 = src->nb[2];
  275. const uint64_t nb3 = src->nb[3];
  276. const enum ggml_type type = src->type;
  277. const size_t ts = ggml_type_size(type);
  278. const size_t bs = ggml_blck_size(type);
  279. const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
  280. if (nb0 == ts && nb1 == ts*ne0/bs) {
  281. return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
  282. } else if (nb0 == ts) {
  283. return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
  284. } else {
  285. for (uint64_t i1 = 0; i1 < ne1; i1++) {
  286. const void * rx = (const void *) ((const char *) x + i1*nb1);
  287. void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
  288. // pretend the row is a matrix with cols=1
  289. cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
  290. if (r != cudaSuccess) return r;
  291. }
  292. return cudaSuccess;
  293. }
  294. }
  295. static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  296. const int64_t ne00 = src0->ne[0];
  297. const int64_t ne01 = src0->ne[1];
  298. const int64_t ne02 = src0->ne[2];
  299. const int64_t ne03 = src0->ne[3];
  300. const int64_t ne10 = src1->ne[0];
  301. const int64_t ne11 = src1->ne[1];
  302. const int nb2 = dst->nb[2];
  303. const int nb3 = dst->nb[3];
  304. const float alpha = 1.0f;
  305. const float beta = 0.0f;
  306. const int x_ne = ne01 * ne00;
  307. const int y_ne = ne11 * ne10;
  308. const int d_ne = ne11 * ne01;
  309. const int n_mm = ne03 * ne02;
  310. size_t x_size, y_size, d_size;
  311. float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  312. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  313. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  314. for (int64_t i03 = 0; i03 < ne03; i03++) {
  315. for (int64_t i02 = 0; i02 < ne02; i02++) {
  316. int i = i03*ne02 + i02;
  317. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  318. float * c_X = d_X + i * x_ne;
  319. float * c_Y = d_Y + i * y_ne;
  320. float * c_D = d_D + i * d_ne;
  321. // copy data to device
  322. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  323. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  324. // compute
  325. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  326. CUBLAS_CHECK(
  327. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  328. ne01, ne11, ne10,
  329. &alpha, c_X, ne00,
  330. c_Y, ne10,
  331. &beta, c_D, ne01));
  332. // copy dst to host
  333. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  334. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  335. }
  336. }
  337. CUDA_CHECK(cudaDeviceSynchronize());
  338. ggml_cuda_pool_free(d_X, x_size);
  339. ggml_cuda_pool_free(d_Y, y_size);
  340. ggml_cuda_pool_free(d_D, d_size);
  341. }
  342. static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
  343. const int64_t ne00 = src0->ne[0];
  344. const int64_t ne01 = src0->ne[1];
  345. const int64_t ne02 = src0->ne[2];
  346. const int64_t ne03 = src0->ne[3];
  347. const int64_t ne10 = src1->ne[0];
  348. const int64_t ne11 = src1->ne[1];
  349. const int nb10 = src1->nb[0];
  350. const int nb11 = src1->nb[1];
  351. const int nb12 = src1->nb[2];
  352. const int nb13 = src1->nb[3];
  353. const int nb2 = dst->nb[2];
  354. const int nb3 = dst->nb[3];
  355. const float alpha = 1.0f;
  356. const float beta = 0.0f;
  357. const int x_ne = ne01 * ne00;
  358. const int y_ne = ne11 * ne10;
  359. const int d_ne = ne11 * ne01;
  360. const int n_mm = ne03 * ne02;
  361. size_t x_size, y_size, d_size;
  362. half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
  363. half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
  364. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  365. bool src1_cont_rows = nb10 == sizeof(float);
  366. bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
  367. for (int64_t i03 = 0; i03 < ne03; i03++) {
  368. for (int64_t i02 = 0; i02 < ne02; i02++) {
  369. int i = i03*ne02 + i02;
  370. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  371. half * c_X = d_X + i * x_ne;
  372. half * c_Y = d_Y + i * y_ne;
  373. float * c_D = d_D + i * d_ne;
  374. // copy src0 to device
  375. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  376. // convert src1 to fp16
  377. // TODO: use multiple threads
  378. ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
  379. char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
  380. if (src1_cont_rows) {
  381. if (src1_cont_cols) {
  382. ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
  383. }
  384. else {
  385. for (int64_t i01 = 0; i01 < ne11; i01++) {
  386. ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
  387. }
  388. }
  389. }
  390. else {
  391. for (int64_t i01 = 0; i01 < ne11; i01++) {
  392. for (int64_t i00 = 0; i00 < ne10; i00++) {
  393. // very slow due to no inlining
  394. tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
  395. }
  396. }
  397. }
  398. // copy src1 to device
  399. CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
  400. // compute
  401. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  402. CUBLAS_CHECK(
  403. cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  404. ne01, ne11, ne10,
  405. &alpha, c_X, CUDA_R_16F, ne00,
  406. c_Y, CUDA_R_16F, ne10,
  407. &beta, c_D, CUDA_R_32F, ne01,
  408. CUBLAS_COMPUTE_32F_FAST_16F,
  409. CUBLAS_GEMM_DEFAULT));
  410. // copy dst to host
  411. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  412. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  413. }
  414. }
  415. CUDA_CHECK(cudaDeviceSynchronize());
  416. ggml_cuda_pool_free(d_X, x_size);
  417. ggml_cuda_pool_free(d_Y, y_size);
  418. ggml_cuda_pool_free(d_D, d_size);
  419. }
  420. static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  421. const int64_t ne00 = src0->ne[0];
  422. const int64_t ne01 = src0->ne[1];
  423. const int64_t ne02 = src0->ne[2];
  424. const int64_t ne03 = src0->ne[3];
  425. const int64_t ne10 = src1->ne[0];
  426. const int64_t ne11 = src1->ne[1];
  427. const int nb2 = dst->nb[2];
  428. const int nb3 = dst->nb[3];
  429. const ggml_type type = src0->type;
  430. const float alpha = 1.0f;
  431. const float beta = 0.0f;
  432. const int x_ne = ne01 * ne00;
  433. const int y_ne = ne11 * ne10;
  434. const int d_ne = ne11 * ne01;
  435. const int n_mm = ne03 * ne02;
  436. const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
  437. size_t x_size, y_size, d_size, q_size;
  438. float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  439. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  440. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  441. char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
  442. const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
  443. GGML_ASSERT(to_fp32_cuda != nullptr);
  444. for (int64_t i03 = 0; i03 < ne03; i03++) {
  445. for (int64_t i02 = 0; i02 < ne02; i02++) {
  446. int i = i03*ne02 + i02;
  447. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  448. cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
  449. cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
  450. float * c_X = d_X + i * x_ne;
  451. float * c_Y = d_Y + i * y_ne;
  452. float * c_D = d_D + i * d_ne;
  453. char * c_Q = d_Q + i * q_sz;
  454. // copy src0 and convert to fp32 on device
  455. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
  456. to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
  457. CUDA_CHECK(cudaGetLastError());
  458. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  459. // copy src1 to device
  460. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  461. // wait for conversion
  462. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  463. // compute
  464. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  465. CUBLAS_CHECK(
  466. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  467. ne01, ne11, ne10,
  468. &alpha, c_X, ne00,
  469. c_Y, ne10,
  470. &beta, c_D, ne01));
  471. // copy dst to host
  472. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  473. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  474. }
  475. }
  476. CUDA_CHECK(cudaDeviceSynchronize());
  477. ggml_cuda_pool_free(d_X, x_size);
  478. ggml_cuda_pool_free(d_Y, y_size);
  479. ggml_cuda_pool_free(d_D, d_size);
  480. ggml_cuda_pool_free(d_Q, q_size);
  481. }
  482. bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  483. const int64_t ne10 = src1->ne[0];
  484. const int64_t ne0 = dst->ne[0];
  485. const int64_t ne1 = dst->ne[1];
  486. // TODO: find the optimal values for these
  487. if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
  488. src1->type == GGML_TYPE_F32 &&
  489. dst->type == GGML_TYPE_F32 &&
  490. (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
  491. return true;
  492. }
  493. return false;
  494. }
  495. bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
  496. size_t src0_sz = ggml_nbytes(src0);
  497. size_t src1_sz = ggml_nbytes(src1);
  498. // mul_mat_q: src0 is converted to fp32 on device
  499. size_t mul_mat_q_transfer = src0_sz + src1_sz;
  500. // mul_mat_f16: src1 is converted to fp16 on cpu
  501. size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
  502. // choose the smaller one to transfer to the device
  503. // TODO: this is not always the best choice due to the overhead of converting to fp16
  504. return mul_mat_f16_transfer < mul_mat_q_transfer;
  505. }
  506. void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
  507. GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
  508. if (src0->type == GGML_TYPE_F32) {
  509. ggml_cuda_mul_mat_f32(src0, src1, dst);
  510. }
  511. else if (src0->type == GGML_TYPE_F16) {
  512. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  513. ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
  514. }
  515. else {
  516. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  517. }
  518. }
  519. else if (ggml_is_quantized(src0->type)) {
  520. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  521. }
  522. else {
  523. GGML_ASSERT(false);
  524. }
  525. }
  526. size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  527. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  528. return ggml_nelements(src1) * sizeof(ggml_fp16_t);
  529. }
  530. else {
  531. return 0;
  532. }
  533. }