ggml-cuda.cu 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <stdint.h>
  4. #include <stdio.h>
  5. #include <atomic>
  6. #include <cuda_runtime.h>
  7. #include <cublas_v2.h>
  8. #include <cuda_fp16.h>
  9. #include "ggml-cuda.h"
  10. #include "ggml.h"
// ggml_fp16_t buffers are reinterpreted as CUDA `half` throughout this file,
// so the two types must have the same size.
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");

// Evaluates a CUDA runtime call exactly once. On any status other than
// cudaSuccess it prints the numeric code, the call site and the
// cudaGetErrorString() text to stderr, then terminates the process with exit(1).
#define CUDA_CHECK(err)                                                                 \
    do {                                                                                \
        cudaError_t err_ = (err);                                                       \
        if (err_ != cudaSuccess) {                                                      \
            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__,  \
                cudaGetErrorString(err_));                                              \
            exit(1);                                                                    \
        }                                                                               \
    } while (0)

// cuBLAS counterpart of CUDA_CHECK: evaluates the call once and exits with
// status 1 (after printing the numeric cublasStatus_t and call site) on failure.
#define CUBLAS_CHECK(err)                                                               \
    do {                                                                                \
        cublasStatus_t err_ = (err);                                                    \
        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
            exit(1);                                                                    \
        }                                                                               \
    } while (0)

// Device helper: decodes the two values addressed by (block ib, quant index iqs)
// of the packed buffer vx into v0 and v1.
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
// Host launcher: converts k values of x into fp32 y, asynchronously on `stream`.
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
// Host launcher: for each of nrows rows of vx (ncols values each), writes
// dst[row] = dot(row, y), asynchronously on `stream`.
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
// Quantized block layouts. Blocks are copied byte-for-byte from host tensors,
// so the static_asserts below guard against any compiler-inserted padding.
//
// QK = number of values after dequantization (values per block)
// QR = QK / number of values before dequantization, i.e. how many decoded
//      values share one stored quant element (2 for the nibble formats, 1 for q8_0)
#define QK4_0 32
#define QR4_0 2
// 4-bit symmetric: value = (nibble - 8) * d.
// qs[j] holds value j in its low nibble and value j + QK4_0/2 in its high nibble.
typedef struct {
    half    d;              // delta (scale)
    uint8_t qs[QK4_0 / 2];  // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");

#define QK4_1 32
#define QR4_1 2
// 4-bit with offset: value = nibble * d + m (same nibble packing as q4_0).
typedef struct {
    half    d;              // delta (scale)
    half    m;              // min (offset)
    uint8_t qs[QK4_1 / 2];  // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");

#define QK5_0 32
#define QR5_0 2
// 5-bit symmetric: value = ((nibble | fifth_bit << 4) - 16) * d.
// qh is read as one 32-bit mask whose bit j is the fifth bit of value j.
typedef struct {
    half    d;              // delta (scale)
    uint8_t qh[4];          // 5-th bit of each of the 32 quants
    uint8_t qs[QK5_0 / 2];  // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");

#define QK5_1 32
#define QR5_1 2
// 5-bit with offset: value = (nibble | fifth_bit << 4) * d + m.
typedef struct {
    half    d;              // delta (scale)
    half    m;              // min (offset)
    uint8_t qh[4];          // 5-th bit of each of the 32 quants
    uint8_t qs[QK5_1 / 2];  // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");

#define QK8_0 32
#define QR8_0 1
// 8-bit symmetric: value = qs[j] * d.
typedef struct {
    half   d;               // delta (scale)
    int8_t qs[QK8_0];       // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");

// Threads per block for the dequantize_block launches.
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
// Threads per block for dequantize_mul_mat_vec (one block per output row).
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
  75. static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  76. const block_q4_0 * x = (const block_q4_0 *) vx;
  77. const float d = x[ib].d;
  78. const uint8_t vui = x[ib].qs[iqs];
  79. const int8_t vi0 = vui & 0xF;
  80. const int8_t vi1 = vui >> 4;
  81. v0 = (vi0 - 8)*d;
  82. v1 = (vi1 - 8)*d;
  83. }
  84. static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  85. const block_q4_1 * x = (const block_q4_1 *) vx;
  86. const float d = x[ib].d;
  87. const float m = x[ib].m;
  88. const uint8_t vui = x[ib].qs[iqs];
  89. const int8_t vi0 = vui & 0xF;
  90. const int8_t vi1 = vui >> 4;
  91. v0 = vi0*d + m;
  92. v1 = vi1*d + m;
  93. }
  94. static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  95. const block_q5_0 * x = (const block_q5_0 *) vx;
  96. const float d = x[ib].d;
  97. uint32_t qh;
  98. memcpy(&qh, x[ib].qh, sizeof(qh));
  99. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  100. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  101. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
  102. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
  103. v0 = x0*d;
  104. v1 = x1*d;
  105. }
  106. static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  107. const block_q5_1 * x = (const block_q5_1 *) vx;
  108. const float d = x[ib].d;
  109. const float m = x[ib].m;
  110. uint32_t qh;
  111. memcpy(&qh, x[ib].qh, sizeof(qh));
  112. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  113. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  114. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
  115. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
  116. v0 = x0*d + m;
  117. v1 = x1*d + m;
  118. }
  119. static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  120. const block_q8_0 * x = (const block_q8_0 *) vx;
  121. const float d = x[ib].d;
  122. const int8_t vi0 = x[ib].qs[iqs + 0];
  123. const int8_t vi1 = x[ib].qs[iqs + 1];
  124. v0 = vi0*d;
  125. v1 = vi1*d;
  126. }
  127. static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  128. const half * x = (const half *) vx;
  129. v0 = __half2float(x[ib + 0]);
  130. v1 = __half2float(x[ib + 1]);
  131. }
  132. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  133. static __global__ void dequantize_block(const void * vx, float * y, const int k) {
  134. const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
  135. if (i >= k) {
  136. return;
  137. }
  138. const int ib = i/qk; // block index
  139. const int iqs = (i%qk)/qr; // quant index
  140. const int iybs = i - i%qk; // y block start index
  141. const int y_offset = qr == 1 ? 1 : qk/2;
  142. // dequantize
  143. float & v0 = y[iybs + iqs + 0];
  144. float & v1 = y[iybs + iqs + y_offset];
  145. dequantize_kernel(vx, ib, iqs, v0, v1);
  146. }
  147. template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
  148. static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
  149. const int row = blockIdx.x;
  150. const int tid = threadIdx.x;
  151. const int y_offset = qr == 1 ? 1 : qk/2;
  152. __shared__ float tmp[block_size]; // separate sum for each thread
  153. tmp[tid] = 0;
  154. for (int i = 0; i < ncols/block_size; i += 2) {
  155. const int col = i*block_size + 2*tid;
  156. const int ib = (row*ncols + col)/qk; // block index
  157. const int iqs = (col%qk)/qr; // quant index
  158. const int iybs = col - col%qk; // y block start index
  159. // dequantize
  160. float v0, v1;
  161. dequantize_kernel(vx, ib, iqs, v0, v1);
  162. // matrix multiplication
  163. tmp[tid] += v0 * y[iybs + iqs + 0];
  164. tmp[tid] += v1 * y[iybs + iqs + y_offset];
  165. }
  166. // sum up partial sums and write back result
  167. __syncthreads();
  168. for (int s=block_size/2; s>0; s>>=1) {
  169. if (tid < s) {
  170. tmp[tid] += tmp[tid + s];
  171. }
  172. __syncthreads();
  173. }
  174. if (tid == 0) {
  175. dst[row] = tmp[0];
  176. }
  177. }
  178. static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  179. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  180. dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  181. }
  182. static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  183. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  184. dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  185. }
  186. static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  187. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  188. dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  189. }
  190. static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  191. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  192. dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  193. }
  194. static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  195. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  196. dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  197. }
  198. static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  199. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  200. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
  201. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  202. }
  203. static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  204. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  205. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
  206. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  207. }
  208. static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  209. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  210. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
  211. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  212. }
  213. static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  214. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  215. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
  216. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  217. }
  218. static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  219. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  220. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
  221. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  222. }
  223. static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  224. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  225. dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  226. }
  227. static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  228. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  229. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
  230. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  231. }
  232. static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  233. switch (type) {
  234. case GGML_TYPE_Q4_0:
  235. return dequantize_row_q4_0_cuda;
  236. case GGML_TYPE_Q4_1:
  237. return dequantize_row_q4_1_cuda;
  238. case GGML_TYPE_Q5_0:
  239. return dequantize_row_q5_0_cuda;
  240. case GGML_TYPE_Q5_1:
  241. return dequantize_row_q5_1_cuda;
  242. case GGML_TYPE_Q8_0:
  243. return dequantize_row_q8_0_cuda;
  244. case GGML_TYPE_F16:
  245. return convert_fp16_to_fp32_cuda;
  246. default:
  247. return nullptr;
  248. }
  249. }
  250. static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
  251. switch (type) {
  252. case GGML_TYPE_Q4_0:
  253. return dequantize_mul_mat_vec_q4_0_cuda;
  254. case GGML_TYPE_Q4_1:
  255. return dequantize_mul_mat_vec_q4_1_cuda;
  256. case GGML_TYPE_Q5_0:
  257. return dequantize_mul_mat_vec_q5_0_cuda;
  258. case GGML_TYPE_Q5_1:
  259. return dequantize_mul_mat_vec_q5_1_cuda;
  260. case GGML_TYPE_Q8_0:
  261. return dequantize_mul_mat_vec_q8_0_cuda;
  262. case GGML_TYPE_F16:
  263. return convert_mul_mat_vec_f16_cuda;
  264. default:
  265. return nullptr;
  266. }
  267. }
// buffer pool for cuda
// Number of slots for freed device buffers kept for reuse. When every slot is
// occupied, ggml_cuda_pool_free releases the buffer with cudaFree instead.
#define MAX_CUDA_BUFFERS 256
  270. struct scoped_spin_lock {
  271. std::atomic_flag& lock;
  272. scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
  273. while (lock.test_and_set(std::memory_order_acquire)) {
  274. ; // spin
  275. }
  276. }
  277. ~scoped_spin_lock() {
  278. lock.clear(std::memory_order_release);
  279. }
  280. scoped_spin_lock(const scoped_spin_lock&) = delete;
  281. scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  282. };
  283. struct cuda_buffer {
  284. void * ptr = nullptr;
  285. size_t size = 0;
  286. };
// Device allocations returned by ggml_cuda_pool_free, handed out again by ggml_cuda_pool_malloc.
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
// Spin lock taken by ggml_cuda_pool_malloc / ggml_cuda_pool_free around every pool access.
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  289. static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
  290. scoped_spin_lock lock(g_cuda_pool_lock);
  291. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  292. cuda_buffer& b = g_cuda_buffer_pool[i];
  293. if (b.size >= size && b.ptr != nullptr) {
  294. void * ptr = b.ptr;
  295. *actual_size = b.size;
  296. b.ptr = nullptr;
  297. b.size = 0;
  298. return ptr;
  299. }
  300. }
  301. void * ptr;
  302. CUDA_CHECK(cudaMalloc((void **) &ptr, size));
  303. *actual_size = size;
  304. return ptr;
  305. }
  306. static void ggml_cuda_pool_free(void * ptr, size_t size) {
  307. scoped_spin_lock lock(g_cuda_pool_lock);
  308. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  309. cuda_buffer& b = g_cuda_buffer_pool[i];
  310. if (b.ptr == nullptr) {
  311. b.ptr = ptr;
  312. b.size = size;
  313. return;
  314. }
  315. }
  316. fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
  317. CUDA_CHECK(cudaFree(ptr));
  318. }
// Slices of a batched matmul are distributed round-robin over this many streams
// (index i % GGML_CUDA_MAX_STREAMS).
#define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
// Events used to order work on g_cudaStreams2 before work on g_cudaStreams (index i % GGML_CUDA_MAX_EVENTS).
#define GGML_CUDA_MAX_EVENTS 64
// Shared cuBLAS handle; null until ggml_init_cublas has run.
static cublasHandle_t g_cublasH = nullptr;
// Main streams: src1 upload, compute and result download.
static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
// Secondary streams: src0 upload / dequantization in ggml_cuda_mul_mat_q_f32 and ggml_cuda_transform_tensor.
static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
  325. void ggml_init_cublas() {
  326. if (g_cublasH == nullptr) {
  327. // create streams
  328. for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
  329. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
  330. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
  331. }
  332. // create events
  333. for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
  334. CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
  335. }
  336. // create cublas handle
  337. CUBLAS_CHECK(cublasCreate(&g_cublasH));
  338. CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
  339. // configure logging to stdout
  340. // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  341. }
  342. }
  343. void * ggml_cuda_host_malloc(size_t size) {
  344. if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
  345. return nullptr;
  346. }
  347. void * ptr = nullptr;
  348. cudaError_t err = cudaMallocHost((void **) &ptr, size);
  349. if (err != cudaSuccess) {
  350. fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
  351. size/1024.0/1024.0, cudaGetErrorString(err));
  352. return nullptr;
  353. }
  354. return ptr;
  355. }
  356. void ggml_cuda_host_free(void * ptr) {
  357. CUDA_CHECK(cudaFreeHost(ptr));
  358. }
  359. static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
  360. const uint64_t ne0 = src->ne[0];
  361. const uint64_t ne1 = src->ne[1];
  362. const uint64_t nb0 = src->nb[0];
  363. const uint64_t nb1 = src->nb[1];
  364. const uint64_t nb2 = src->nb[2];
  365. const uint64_t nb3 = src->nb[3];
  366. const enum ggml_type type = src->type;
  367. const size_t ts = ggml_type_size(type);
  368. const size_t bs = ggml_blck_size(type);
  369. const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
  370. if (nb0 == ts && nb1 == ts*ne0/bs) {
  371. return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
  372. } else if (nb0 == ts) {
  373. return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
  374. } else {
  375. for (uint64_t i1 = 0; i1 < ne1; i1++) {
  376. const void * rx = (const void *) ((const char *) x + i1*nb1);
  377. void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
  378. // pretend the row is a matrix with cols=1
  379. cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
  380. if (r != cudaSuccess) return r;
  381. }
  382. return cudaSuccess;
  383. }
  384. }
  385. static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  386. const int64_t ne00 = src0->ne[0];
  387. const int64_t ne01 = src0->ne[1];
  388. const int64_t ne02 = src0->ne[2];
  389. const int64_t ne03 = src0->ne[3];
  390. const int64_t ne10 = src1->ne[0];
  391. const int64_t ne11 = src1->ne[1];
  392. const int nb2 = dst->nb[2];
  393. const int nb3 = dst->nb[3];
  394. const float alpha = 1.0f;
  395. const float beta = 0.0f;
  396. const int x_ne = ne01 * ne00;
  397. const int y_ne = ne11 * ne10;
  398. const int d_ne = ne11 * ne01;
  399. const int n_mm = ne03 * ne02;
  400. size_t x_size, y_size, d_size;
  401. float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  402. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  403. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  404. for (int64_t i03 = 0; i03 < ne03; i03++) {
  405. for (int64_t i02 = 0; i02 < ne02; i02++) {
  406. int i = i03*ne02 + i02;
  407. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  408. float * c_X = d_X + i * x_ne;
  409. float * c_Y = d_Y + i * y_ne;
  410. float * c_D = d_D + i * d_ne;
  411. // copy data to device
  412. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  413. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  414. // compute
  415. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  416. CUBLAS_CHECK(
  417. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  418. ne01, ne11, ne10,
  419. &alpha, c_X, ne00,
  420. c_Y, ne10,
  421. &beta, c_D, ne01));
  422. // copy dst to host
  423. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  424. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  425. }
  426. }
  427. CUDA_CHECK(cudaDeviceSynchronize());
  428. ggml_cuda_pool_free(d_X, x_size);
  429. ggml_cuda_pool_free(d_Y, y_size);
  430. ggml_cuda_pool_free(d_D, d_size);
  431. }
  432. static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
  433. const int64_t ne00 = src0->ne[0];
  434. const int64_t ne01 = src0->ne[1];
  435. const int64_t ne02 = src0->ne[2];
  436. const int64_t ne03 = src0->ne[3];
  437. const int64_t ne10 = src1->ne[0];
  438. const int64_t ne11 = src1->ne[1];
  439. const int nb10 = src1->nb[0];
  440. const int nb11 = src1->nb[1];
  441. const int nb12 = src1->nb[2];
  442. const int nb13 = src1->nb[3];
  443. const int nb2 = dst->nb[2];
  444. const int nb3 = dst->nb[3];
  445. const float alpha = 1.0f;
  446. const float beta = 0.0f;
  447. const int x_ne = ne01 * ne00;
  448. const int y_ne = ne11 * ne10;
  449. const int d_ne = ne11 * ne01;
  450. const int n_mm = ne03 * ne02;
  451. size_t x_size, y_size, d_size;
  452. half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
  453. half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
  454. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  455. bool src1_cont_rows = nb10 == sizeof(float);
  456. bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
  457. for (int64_t i03 = 0; i03 < ne03; i03++) {
  458. for (int64_t i02 = 0; i02 < ne02; i02++) {
  459. int i = i03*ne02 + i02;
  460. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  461. half * c_X = d_X + i * x_ne;
  462. half * c_Y = d_Y + i * y_ne;
  463. float * c_D = d_D + i * d_ne;
  464. // copy src0 to device
  465. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  466. // convert src1 to fp16
  467. // TODO: use multiple threads
  468. ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
  469. char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
  470. if (src1_cont_rows) {
  471. if (src1_cont_cols) {
  472. ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
  473. }
  474. else {
  475. for (int64_t i01 = 0; i01 < ne11; i01++) {
  476. ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
  477. }
  478. }
  479. }
  480. else {
  481. for (int64_t i01 = 0; i01 < ne11; i01++) {
  482. for (int64_t i00 = 0; i00 < ne10; i00++) {
  483. // very slow due to no inlining
  484. tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
  485. }
  486. }
  487. }
  488. // copy src1 to device
  489. CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
  490. // compute
  491. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  492. CUBLAS_CHECK(
  493. cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  494. ne01, ne11, ne10,
  495. &alpha, c_X, CUDA_R_16F, ne00,
  496. c_Y, CUDA_R_16F, ne10,
  497. &beta, c_D, CUDA_R_32F, ne01,
  498. CUBLAS_COMPUTE_32F_FAST_16F,
  499. CUBLAS_GEMM_DEFAULT));
  500. // copy dst to host
  501. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  502. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  503. }
  504. }
  505. CUDA_CHECK(cudaDeviceSynchronize());
  506. ggml_cuda_pool_free(d_X, x_size);
  507. ggml_cuda_pool_free(d_Y, y_size);
  508. ggml_cuda_pool_free(d_D, d_size);
  509. }
  510. static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  511. const int64_t ne00 = src0->ne[0];
  512. const int64_t ne01 = src0->ne[1];
  513. const int64_t ne02 = src0->ne[2];
  514. const int64_t ne03 = src0->ne[3];
  515. const int64_t ne10 = src1->ne[0];
  516. const int64_t ne11 = src1->ne[1];
  517. const int nb2 = dst->nb[2];
  518. const int nb3 = dst->nb[3];
  519. const ggml_type type = src0->type;
  520. const bool mul_mat_vec = ne11 == 1;
  521. const float alpha = 1.0f;
  522. const float beta = 0.0f;
  523. const int x_ne = ne01 * ne00;
  524. const int y_ne = ne11 * ne10;
  525. const int d_ne = ne11 * ne01;
  526. const int n_mm = ne03 * ne02;
  527. const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
  528. size_t x_size, y_size, d_size, q_size;
  529. float * d_X = nullptr;
  530. if (!mul_mat_vec) {
  531. d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  532. }
  533. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  534. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  535. char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
  536. const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
  537. dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
  538. GGML_ASSERT(to_fp32_cuda != nullptr);
  539. for (int64_t i03 = 0; i03 < ne03; i03++) {
  540. for (int64_t i02 = 0; i02 < ne02; i02++) {
  541. int i = i03*ne02 + i02;
  542. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  543. cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
  544. cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
  545. float * c_Y = d_Y + i * y_ne;
  546. float * c_D = d_D + i * d_ne;
  547. char * c_Q = d_Q + i * q_sz;
  548. // copy src0 to device if necessary
  549. if (src0->backend == GGML_BACKEND_CPU) {
  550. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
  551. } else if (src0->backend == GGML_BACKEND_CUDA) {
  552. c_Q = ((char *) src0->data) + i * q_sz;
  553. } else {
  554. GGML_ASSERT(false);
  555. }
  556. if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
  557. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  558. // copy src1 to device
  559. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  560. // wait for data
  561. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  562. // compute
  563. dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
  564. CUDA_CHECK(cudaGetLastError());
  565. } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
  566. float * c_X = d_X + i * x_ne;
  567. // convert src0 to fp32 on device
  568. to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
  569. CUDA_CHECK(cudaGetLastError());
  570. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  571. // copy src1 to device
  572. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  573. // wait for conversion
  574. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  575. // compute
  576. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  577. CUBLAS_CHECK(
  578. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  579. ne01, ne11, ne10,
  580. &alpha, c_X, ne00,
  581. c_Y, ne10,
  582. &beta, c_D, ne01));
  583. }
  584. // copy dst to host
  585. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  586. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  587. }
  588. }
  589. CUDA_CHECK(cudaDeviceSynchronize());
  590. if (!mul_mat_vec) {
  591. ggml_cuda_pool_free(d_X, x_size);
  592. }
  593. ggml_cuda_pool_free(d_Y, y_size);
  594. ggml_cuda_pool_free(d_D, d_size);
  595. ggml_cuda_pool_free(d_Q, q_size);
  596. }
  597. bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  598. const int64_t ne10 = src1->ne[0];
  599. const int64_t ne0 = dst->ne[0];
  600. const int64_t ne1 = dst->ne[1];
  601. // TODO: find the optimal values for these
  602. if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
  603. src1->type == GGML_TYPE_F32 &&
  604. dst->type == GGML_TYPE_F32 &&
  605. ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
  606. return true;
  607. }
  608. return false;
  609. }
  610. bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
  611. size_t src0_sz = ggml_nbytes(src0);
  612. size_t src1_sz = ggml_nbytes(src1);
  613. // mul_mat_q: src0 is converted to fp32 on device
  614. size_t mul_mat_q_transfer = src0_sz + src1_sz;
  615. // mul_mat_f16: src1 is converted to fp16 on cpu
  616. size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
  617. // choose the smaller one to transfer to the device
  618. // TODO: this is not always the best choice due to the overhead of converting to fp16
  619. return mul_mat_f16_transfer < mul_mat_q_transfer;
  620. }
  621. void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
  622. GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
  623. if (src0->type == GGML_TYPE_F32) {
  624. ggml_cuda_mul_mat_f32(src0, src1, dst);
  625. }
  626. else if (src0->type == GGML_TYPE_F16) {
  627. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  628. ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
  629. }
  630. else {
  631. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  632. }
  633. }
  634. else if (ggml_is_quantized(src0->type)) {
  635. ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  636. }
  637. else {
  638. GGML_ASSERT(false);
  639. }
  640. }
  641. size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
  642. if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
  643. return ggml_nelements(src1) * sizeof(ggml_fp16_t);
  644. }
  645. else {
  646. return 0;
  647. }
  648. }
  649. void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
  650. const int64_t ne0 = tensor->ne[0];
  651. const int64_t ne1 = tensor->ne[1];
  652. const int64_t ne2 = tensor->ne[2];
  653. const int64_t ne3 = tensor->ne[3];
  654. const ggml_type type = tensor->type;
  655. const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
  656. size_t q_size;
  657. char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
  658. cudaStream_t cudaStream2 = g_cudaStreams2[0];
  659. // copy tensor to device
  660. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
  661. CUDA_CHECK(cudaDeviceSynchronize());
  662. tensor->data = d_Q;
  663. tensor->backend = GGML_BACKEND_CUDA;
  664. }