ggml-cuda.cu 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <stdint.h>
  4. #include <stdio.h>
  5. #include <atomic>
  6. #include <cuda_runtime.h>
  7. #include <cublas_v2.h>
  8. #include <cuda_fp16.h>
  9. #include "ggml-cuda.h"
  10. #include "ggml.h"
  11. static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
  12. #define CUDA_CHECK(err) \
  13. do { \
  14. cudaError_t err_ = (err); \
  15. if (err_ != cudaSuccess) { \
  16. fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
  17. cudaGetErrorString(err_)); \
  18. exit(1); \
  19. } \
  20. } while (0)
  21. #define CUBLAS_CHECK(err) \
  22. do { \
  23. cublasStatus_t err_ = (err); \
  24. if (err_ != CUBLAS_STATUS_SUCCESS) { \
  25. fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
  26. exit(1); \
  27. } \
  28. } while (0)
  29. typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
  30. typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
  31. typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
  32. // QK = number of values after dequantization
  33. // QR = QK / number of values before dequantization
  34. #define QK4_0 32
  35. #define QR4_0 2
  36. typedef struct {
  37. half d; // delta
  38. uint8_t qs[QK4_0 / 2]; // nibbles / quants
  39. } block_q4_0;
  40. static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
  41. #define QK4_1 32
  42. #define QR4_1 2
  43. typedef struct {
  44. half d; // delta
  45. half m; // min
  46. uint8_t qs[QK4_1 / 2]; // nibbles / quants
  47. } block_q4_1;
  48. static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
  49. #define QK5_0 32
  50. #define QR5_0 2
  51. typedef struct {
  52. half d; // delta
  53. uint8_t qh[4]; // 5-th bit of quants
  54. uint8_t qs[QK5_0 / 2]; // nibbles / quants
  55. } block_q5_0;
  56. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  57. #define QK5_1 32
  58. #define QR5_1 2
  59. typedef struct {
  60. half d; // delta
  61. half m; // min
  62. uint8_t qh[4]; // 5-th bit of quants
  63. uint8_t qs[QK5_1 / 2]; // nibbles / quants
  64. } block_q5_1;
  65. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  66. #define QK8_0 32
  67. #define QR8_0 1
  68. typedef struct {
  69. half d; // delta
  70. int8_t qs[QK8_0]; // quants
  71. } block_q8_0;
  72. static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
  73. #define CUDA_MUL_BLOCK_SIZE 256
  74. #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
  75. #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
  76. static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
  77. const int i = blockDim.x*blockIdx.x + threadIdx.x;
  78. if (i >= kx) {
  79. return;
  80. }
  81. dst[i] = x[i] * y[i%ky];
  82. }
  83. static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  84. const block_q4_0 * x = (const block_q4_0 *) vx;
  85. const float d = x[ib].d;
  86. const uint8_t vui = x[ib].qs[iqs];
  87. const int8_t vi0 = vui & 0xF;
  88. const int8_t vi1 = vui >> 4;
  89. v0 = (vi0 - 8)*d;
  90. v1 = (vi1 - 8)*d;
  91. }
  92. static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  93. const block_q4_1 * x = (const block_q4_1 *) vx;
  94. const float d = x[ib].d;
  95. const float m = x[ib].m;
  96. const uint8_t vui = x[ib].qs[iqs];
  97. const int8_t vi0 = vui & 0xF;
  98. const int8_t vi1 = vui >> 4;
  99. v0 = vi0*d + m;
  100. v1 = vi1*d + m;
  101. }
  102. static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  103. const block_q5_0 * x = (const block_q5_0 *) vx;
  104. const float d = x[ib].d;
  105. uint32_t qh;
  106. memcpy(&qh, x[ib].qh, sizeof(qh));
  107. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  108. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  109. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
  110. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
  111. v0 = x0*d;
  112. v1 = x1*d;
  113. }
  114. static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  115. const block_q5_1 * x = (const block_q5_1 *) vx;
  116. const float d = x[ib].d;
  117. const float m = x[ib].m;
  118. uint32_t qh;
  119. memcpy(&qh, x[ib].qh, sizeof(qh));
  120. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  121. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  122. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
  123. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
  124. v0 = x0*d + m;
  125. v1 = x1*d + m;
  126. }
  127. static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  128. const block_q8_0 * x = (const block_q8_0 *) vx;
  129. const float d = x[ib].d;
  130. const int8_t vi0 = x[ib].qs[iqs + 0];
  131. const int8_t vi1 = x[ib].qs[iqs + 1];
  132. v0 = vi0*d;
  133. v1 = vi1*d;
  134. }
  135. static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  136. const half * x = (const half *) vx;
  137. v0 = __half2float(x[ib + 0]);
  138. v1 = __half2float(x[ib + 1]);
  139. }
  140. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  141. static __global__ void dequantize_block(const void * vx, float * y, const int k) {
  142. const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
  143. if (i >= k) {
  144. return;
  145. }
  146. const int ib = i/qk; // block index
  147. const int iqs = (i%qk)/qr; // quant index
  148. const int iybs = i - i%qk; // y block start index
  149. const int y_offset = qr == 1 ? 1 : qk/2;
  150. // dequantize
  151. float & v0 = y[iybs + iqs + 0];
  152. float & v1 = y[iybs + iqs + y_offset];
  153. dequantize_kernel(vx, ib, iqs, v0, v1);
  154. }
  155. template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
  156. static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
  157. const int row = blockIdx.x;
  158. const int tid = threadIdx.x;
  159. const int y_offset = qr == 1 ? 1 : qk/2;
  160. __shared__ float tmp[block_size]; // separate sum for each thread
  161. tmp[tid] = 0;
  162. for (int i = 0; i < ncols/block_size; i += 2) {
  163. const int col = i*block_size + 2*tid;
  164. const int ib = (row*ncols + col)/qk; // block index
  165. const int iqs = (col%qk)/qr; // quant index
  166. const int iybs = col - col%qk; // y block start index
  167. // dequantize
  168. float v0, v1;
  169. dequantize_kernel(vx, ib, iqs, v0, v1);
  170. // matrix multiplication
  171. tmp[tid] += v0 * y[iybs + iqs + 0];
  172. tmp[tid] += v1 * y[iybs + iqs + y_offset];
  173. }
  174. // sum up partial sums and write back result
  175. __syncthreads();
  176. for (int s=block_size/2; s>0; s>>=1) {
  177. if (tid < s) {
  178. tmp[tid] += tmp[tid + s];
  179. }
  180. __syncthreads();
  181. }
  182. if (tid == 0) {
  183. dst[row] = tmp[0];
  184. }
  185. }
  186. static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
  187. const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
  188. mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
  189. }
  190. static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  191. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  192. dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  193. }
  194. static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  195. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  196. dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  197. }
  198. static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  199. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  200. dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  201. }
  202. static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  203. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  204. dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  205. }
  206. static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  207. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  208. dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  209. }
  210. static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  211. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  212. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
  213. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  214. }
  215. static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  216. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  217. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
  218. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  219. }
  220. static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  221. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  222. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
  223. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  224. }
  225. static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  226. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  227. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
  228. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  229. }
  230. static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  231. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  232. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
  233. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  234. }
  235. static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  236. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  237. dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  238. }
  239. static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
  240. GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
  241. dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
  242. <<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
  243. }
  244. static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  245. switch (type) {
  246. case GGML_TYPE_Q4_0:
  247. return dequantize_row_q4_0_cuda;
  248. case GGML_TYPE_Q4_1:
  249. return dequantize_row_q4_1_cuda;
  250. case GGML_TYPE_Q5_0:
  251. return dequantize_row_q5_0_cuda;
  252. case GGML_TYPE_Q5_1:
  253. return dequantize_row_q5_1_cuda;
  254. case GGML_TYPE_Q8_0:
  255. return dequantize_row_q8_0_cuda;
  256. case GGML_TYPE_F16:
  257. return convert_fp16_to_fp32_cuda;
  258. default:
  259. return nullptr;
  260. }
  261. }
  262. static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
  263. switch (type) {
  264. case GGML_TYPE_Q4_0:
  265. return dequantize_mul_mat_vec_q4_0_cuda;
  266. case GGML_TYPE_Q4_1:
  267. return dequantize_mul_mat_vec_q4_1_cuda;
  268. case GGML_TYPE_Q5_0:
  269. return dequantize_mul_mat_vec_q5_0_cuda;
  270. case GGML_TYPE_Q5_1:
  271. return dequantize_mul_mat_vec_q5_1_cuda;
  272. case GGML_TYPE_Q8_0:
  273. return dequantize_mul_mat_vec_q8_0_cuda;
  274. case GGML_TYPE_F16:
  275. return convert_mul_mat_vec_f16_cuda;
  276. default:
  277. return nullptr;
  278. }
  279. }
  280. // buffer pool for cuda
  281. #define MAX_CUDA_BUFFERS 256
  282. struct scoped_spin_lock {
  283. std::atomic_flag& lock;
  284. scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
  285. while (lock.test_and_set(std::memory_order_acquire)) {
  286. ; // spin
  287. }
  288. }
  289. ~scoped_spin_lock() {
  290. lock.clear(std::memory_order_release);
  291. }
  292. scoped_spin_lock(const scoped_spin_lock&) = delete;
  293. scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  294. };
  295. struct cuda_buffer {
  296. void * ptr = nullptr;
  297. size_t size = 0;
  298. };
  299. static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
  300. static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  301. static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
  302. scoped_spin_lock lock(g_cuda_pool_lock);
  303. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  304. cuda_buffer& b = g_cuda_buffer_pool[i];
  305. if (b.size >= size && b.ptr != nullptr) {
  306. void * ptr = b.ptr;
  307. *actual_size = b.size;
  308. b.ptr = nullptr;
  309. b.size = 0;
  310. return ptr;
  311. }
  312. }
  313. void * ptr;
  314. CUDA_CHECK(cudaMalloc((void **) &ptr, size));
  315. *actual_size = size;
  316. return ptr;
  317. }
  318. static void ggml_cuda_pool_free(void * ptr, size_t size) {
  319. scoped_spin_lock lock(g_cuda_pool_lock);
  320. for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
  321. cuda_buffer& b = g_cuda_buffer_pool[i];
  322. if (b.ptr == nullptr) {
  323. b.ptr = ptr;
  324. b.size = size;
  325. return;
  326. }
  327. }
  328. fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
  329. CUDA_CHECK(cudaFree(ptr));
  330. }
  331. #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
  332. #define GGML_CUDA_MAX_EVENTS 64
  333. static cublasHandle_t g_cublasH = nullptr;
  334. static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
  335. static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
  336. static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
  337. void ggml_init_cublas() {
  338. if (g_cublasH == nullptr) {
  339. // create streams
  340. for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
  341. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
  342. CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
  343. }
  344. // create events
  345. for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
  346. CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
  347. }
  348. // create cublas handle
  349. CUBLAS_CHECK(cublasCreate(&g_cublasH));
  350. CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
  351. // configure logging to stdout
  352. // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  353. }
  354. }
  355. void * ggml_cuda_host_malloc(size_t size) {
  356. if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
  357. return nullptr;
  358. }
  359. void * ptr = nullptr;
  360. cudaError_t err = cudaMallocHost((void **) &ptr, size);
  361. if (err != cudaSuccess) {
  362. fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
  363. size/1024.0/1024.0, cudaGetErrorString(err));
  364. return nullptr;
  365. }
  366. return ptr;
  367. }
  368. void ggml_cuda_host_free(void * ptr) {
  369. CUDA_CHECK(cudaFreeHost(ptr));
  370. }
  371. static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
  372. const uint64_t ne0 = src->ne[0];
  373. const uint64_t ne1 = src->ne[1];
  374. const uint64_t nb0 = src->nb[0];
  375. const uint64_t nb1 = src->nb[1];
  376. const uint64_t nb2 = src->nb[2];
  377. const uint64_t nb3 = src->nb[3];
  378. const enum ggml_type type = src->type;
  379. const size_t ts = ggml_type_size(type);
  380. const size_t bs = ggml_blck_size(type);
  381. const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
  382. if (nb0 == ts && nb1 == ts*ne0/bs) {
  383. return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
  384. } else if (nb0 == ts) {
  385. return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
  386. } else {
  387. for (uint64_t i1 = 0; i1 < ne1; i1++) {
  388. const void * rx = (const void *) ((const char *) x + i1*nb1);
  389. void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
  390. // pretend the row is a matrix with cols=1
  391. cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
  392. if (r != cudaSuccess) return r;
  393. }
  394. return cudaSuccess;
  395. }
  396. }
  397. static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  398. GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
  399. const int64_t ne00 = src0->ne[0];
  400. const int64_t ne01 = src0->ne[1];
  401. const int64_t ne02 = src0->ne[2];
  402. const int64_t ne03 = src0->ne[2];
  403. const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
  404. const int64_t ne10 = src1->ne[0];
  405. const int64_t ne11 = src1->ne[1];
  406. const int64_t ne12 = src1->ne[2];
  407. const int64_t ne13 = src1->ne[3];
  408. const int nb2 = dst->nb[2];
  409. const int nb3 = dst->nb[3];
  410. size_t x_size, d_size;
  411. float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
  412. float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
  413. float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
  414. for (int64_t i03 = 0; i03 < ne03; i03++) {
  415. for (int64_t i02 = 0; i02 < ne02; i02++) {
  416. const int i0 = i03*ne02 + i02;
  417. float * c_X2 = d_X + i0*ne01*ne00;
  418. float * c_D2 = d_D + i0*ne01*ne00;
  419. cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
  420. cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
  421. cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
  422. // copy src0 to device
  423. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
  424. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  425. // wait for data
  426. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  427. for (int64_t i01 = 0; i01 < ne01; i01++) {
  428. const int64_t i13 = i03%ne13;
  429. const int64_t i12 = i02%ne12;
  430. const int64_t i11 = i01%ne11;
  431. const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
  432. float * c_X1 = c_X2 + i01*ne00;
  433. float * c_Y = d_Y + i1*ne10;
  434. float * c_D1 = c_D2 + i01*ne00;
  435. // compute
  436. mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
  437. CUDA_CHECK(cudaGetLastError());
  438. }
  439. // copy dst to host
  440. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  441. CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
  442. }
  443. }
  444. CUDA_CHECK(cudaDeviceSynchronize());
  445. ggml_cuda_pool_free(d_X, x_size);
  446. ggml_cuda_pool_free(d_D, d_size);
  447. }
  448. static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  449. const int64_t ne00 = src0->ne[0];
  450. const int64_t ne01 = src0->ne[1];
  451. const int64_t ne02 = src0->ne[2];
  452. const int64_t ne03 = src0->ne[3];
  453. const int64_t ne10 = src1->ne[0];
  454. const int64_t ne11 = src1->ne[1];
  455. const int nb2 = dst->nb[2];
  456. const int nb3 = dst->nb[3];
  457. const float alpha = 1.0f;
  458. const float beta = 0.0f;
  459. const int x_ne = ne01 * ne00;
  460. const int y_ne = ne11 * ne10;
  461. const int d_ne = ne11 * ne01;
  462. const int n_mm = ne03 * ne02;
  463. size_t x_size, y_size, d_size;
  464. float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  465. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  466. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  467. for (int64_t i03 = 0; i03 < ne03; i03++) {
  468. for (int64_t i02 = 0; i02 < ne02; i02++) {
  469. int i = i03*ne02 + i02;
  470. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  471. float * c_X = d_X + i * x_ne;
  472. float * c_Y = d_Y + i * y_ne;
  473. float * c_D = d_D + i * d_ne;
  474. // copy data to device
  475. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  476. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  477. // compute
  478. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  479. CUBLAS_CHECK(
  480. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  481. ne01, ne11, ne10,
  482. &alpha, c_X, ne00,
  483. c_Y, ne10,
  484. &beta, c_D, ne01));
  485. // copy dst to host
  486. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  487. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  488. }
  489. }
  490. CUDA_CHECK(cudaDeviceSynchronize());
  491. ggml_cuda_pool_free(d_X, x_size);
  492. ggml_cuda_pool_free(d_Y, y_size);
  493. ggml_cuda_pool_free(d_D, d_size);
  494. }
  495. static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
  496. const int64_t ne00 = src0->ne[0];
  497. const int64_t ne01 = src0->ne[1];
  498. const int64_t ne02 = src0->ne[2];
  499. const int64_t ne03 = src0->ne[3];
  500. const int64_t ne10 = src1->ne[0];
  501. const int64_t ne11 = src1->ne[1];
  502. const int nb10 = src1->nb[0];
  503. const int nb11 = src1->nb[1];
  504. const int nb12 = src1->nb[2];
  505. const int nb13 = src1->nb[3];
  506. const int nb2 = dst->nb[2];
  507. const int nb3 = dst->nb[3];
  508. const float alpha = 1.0f;
  509. const float beta = 0.0f;
  510. const int x_ne = ne01 * ne00;
  511. const int y_ne = ne11 * ne10;
  512. const int d_ne = ne11 * ne01;
  513. const int n_mm = ne03 * ne02;
  514. size_t x_size, y_size, d_size;
  515. half * d_X = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
  516. half * d_Y = (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
  517. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  518. bool src1_cont_rows = nb10 == sizeof(float);
  519. bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
  520. for (int64_t i03 = 0; i03 < ne03; i03++) {
  521. for (int64_t i02 = 0; i02 < ne02; i02++) {
  522. int i = i03*ne02 + i02;
  523. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  524. half * c_X = d_X + i * x_ne;
  525. half * c_Y = d_Y + i * y_ne;
  526. float * c_D = d_D + i * d_ne;
  527. // copy src0 to device
  528. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
  529. // convert src1 to fp16
  530. // TODO: use multiple threads
  531. ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
  532. char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
  533. if (src1_cont_rows) {
  534. if (src1_cont_cols) {
  535. ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
  536. }
  537. else {
  538. for (int64_t i01 = 0; i01 < ne11; i01++) {
  539. ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
  540. }
  541. }
  542. }
  543. else {
  544. for (int64_t i01 = 0; i01 < ne11; i01++) {
  545. for (int64_t i00 = 0; i00 < ne10; i00++) {
  546. // very slow due to no inlining
  547. tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
  548. }
  549. }
  550. }
  551. // copy src1 to device
  552. CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
  553. // compute
  554. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  555. CUBLAS_CHECK(
  556. cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  557. ne01, ne11, ne10,
  558. &alpha, c_X, CUDA_R_16F, ne00,
  559. c_Y, CUDA_R_16F, ne10,
  560. &beta, c_D, CUDA_R_32F, ne01,
  561. CUBLAS_COMPUTE_32F_FAST_16F,
  562. CUBLAS_GEMM_DEFAULT));
  563. // copy dst to host
  564. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  565. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  566. }
  567. }
  568. CUDA_CHECK(cudaDeviceSynchronize());
  569. ggml_cuda_pool_free(d_X, x_size);
  570. ggml_cuda_pool_free(d_Y, y_size);
  571. ggml_cuda_pool_free(d_D, d_size);
  572. }
  573. static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
  574. const int64_t ne00 = src0->ne[0];
  575. const int64_t ne01 = src0->ne[1];
  576. const int64_t ne02 = src0->ne[2];
  577. const int64_t ne03 = src0->ne[3];
  578. const int64_t ne10 = src1->ne[0];
  579. const int64_t ne11 = src1->ne[1];
  580. const int nb2 = dst->nb[2];
  581. const int nb3 = dst->nb[3];
  582. const ggml_type type = src0->type;
  583. const bool mul_mat_vec = ne11 == 1;
  584. const float alpha = 1.0f;
  585. const float beta = 0.0f;
  586. const int x_ne = ne01 * ne00;
  587. const int y_ne = ne11 * ne10;
  588. const int d_ne = ne11 * ne01;
  589. const int n_mm = ne03 * ne02;
  590. const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
  591. size_t x_size, y_size, d_size, q_size;
  592. float * d_X = nullptr;
  593. if (!mul_mat_vec) {
  594. d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
  595. }
  596. float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
  597. float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
  598. char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
  599. const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
  600. dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
  601. GGML_ASSERT(to_fp32_cuda != nullptr);
  602. for (int64_t i03 = 0; i03 < ne03; i03++) {
  603. for (int64_t i02 = 0; i02 < ne02; i02++) {
  604. int i = i03*ne02 + i02;
  605. cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
  606. cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
  607. cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
  608. float * c_Y = d_Y + i * y_ne;
  609. float * c_D = d_D + i * d_ne;
  610. char * c_Q = d_Q + i * q_sz;
  611. // copy src0 to device if necessary
  612. if (src0->backend == GGML_BACKEND_CPU) {
  613. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
  614. } else if (src0->backend == GGML_BACKEND_CUDA) {
  615. c_Q = ((char *) src0->data) + i * q_sz;
  616. } else {
  617. GGML_ASSERT(false);
  618. }
  619. if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
  620. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  621. // copy src1 to device
  622. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  623. // wait for data
  624. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  625. // compute
  626. dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
  627. CUDA_CHECK(cudaGetLastError());
  628. } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
  629. float * c_X = d_X + i * x_ne;
  630. // convert src0 to fp32 on device
  631. to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
  632. CUDA_CHECK(cudaGetLastError());
  633. CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
  634. // copy src1 to device
  635. CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
  636. // wait for conversion
  637. CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
  638. // compute
  639. CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
  640. CUBLAS_CHECK(
  641. cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
  642. ne01, ne11, ne10,
  643. &alpha, c_X, ne00,
  644. c_Y, ne10,
  645. &beta, c_D, ne01));
  646. }
  647. // copy dst to host
  648. float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
  649. CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
  650. }
  651. }
  652. CUDA_CHECK(cudaDeviceSynchronize());
  653. if (!mul_mat_vec) {
  654. ggml_cuda_pool_free(d_X, x_size);
  655. }
  656. ggml_cuda_pool_free(d_Y, y_size);
  657. ggml_cuda_pool_free(d_D, d_size);
  658. ggml_cuda_pool_free(d_Q, q_size);
  659. }
  // CUDA entry point for ggml's mul operator.
  //
  // This wrapper only accepts the case where src0, src1 and dst are all
  // GGML_TYPE_F32 (asserted; any other combination aborts). The actual work is
  // forwarded unchanged to ggml_cuda_mul_f32.
  void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      // all three tensors must be f32
      GGML_ASSERT(src0->type == GGML_TYPE_F32 &&
                  src1->type == GGML_TYPE_F32 &&
                  dst->type == GGML_TYPE_F32);

      ggml_cuda_mul_f32(src0, src1, dst);
  }
  // Decide whether a matrix multiplication should be routed to the CUDA backend.
  //
  // Returns true only when:
  //   * src0 is f32, f16 or a quantized type, and src1 and dst are both f32; and
  //   * either src0 already has backend GGML_BACKEND_CUDA, or the product is large
  //     enough: dst->ne[0], dst->ne[1] and src1->ne[0] are all >= 32.
  // Otherwise returns false.
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      const bool src0_supported = src0->type == GGML_TYPE_F32 ||
                                  src0->type == GGML_TYPE_F16 ||
                                  ggml_is_quantized(src0->type);

      if (!src0_supported || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
          return false;
      }

      // src0 already resident on the GPU: accept regardless of size
      if (src0->backend == GGML_BACKEND_CUDA) {
          return true;
      }

      // TODO: find the optimal values for these
      const int64_t min_dim = 32;
      const int64_t rows_out = dst->ne[0];
      const int64_t cols_out = dst->ne[1];
      const int64_t inner    = src1->ne[0];

      return rows_out >= min_dim && cols_out >= min_dim && inner >= min_dim;
  }
  // Heuristic choosing how to run a mat-mul whose src0 is f16.
  //
  // Both candidate paths upload all of src0 (ggml_nbytes(src0)). They differ in src1:
  //   * mul_mat_q:   src0 is converted to fp32 on the device, src1 is uploaded as-is
  //                  (ggml_nbytes(src1) bytes);
  //   * mul_mat_f16: src1 is converted to fp16 on the CPU and uploaded as
  //                  sizeof(half) bytes per element.
  // Returns true when the fp16 path moves strictly fewer bytes to the device.
  // TODO: this is not always the best choice due to the overhead of converting to fp16
  bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
      const size_t src0_bytes = ggml_nbytes(src0);

      const size_t q_path_bytes   = src0_bytes + ggml_nbytes(src1);
      const size_t f16_path_bytes = src0_bytes + ggml_nelements(src1) * sizeof(half);

      return f16_path_bytes < q_path_bytes;
  }
  // CUDA matrix-multiplication entry point; dispatches on src0->type:
  //   GGML_TYPE_F32 -> ggml_cuda_mul_mat_f32
  //   GGML_TYPE_F16 -> ggml_cuda_mul_mat_f16 (the only path that receives wdata/wsize)
  //                    when ggml_cuda_mul_mat_use_f16 returns true,
  //                    otherwise ggml_cuda_mul_mat_q_f32
  //   quantized     -> ggml_cuda_mul_mat_q_f32
  // Any other src0 type aborts. The operands must satisfy ggml_cuda_can_mul_mat
  // (asserted on entry).
  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
      GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));

      switch (src0->type) {
          case GGML_TYPE_F32:
              ggml_cuda_mul_mat_f32(src0, src1, dst);
              return;
          case GGML_TYPE_F16:
              if (!ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
                  ggml_cuda_mul_mat_q_f32(src0, src1, dst);
              } else {
                  ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
              }
              return;
          default:
              break;
      }

      if (!ggml_is_quantized(src0->type)) {
          GGML_ASSERT(false); // unsupported src0 type
      }
      ggml_cuda_mul_mat_q_f32(src0, src1, dst);
  }
  // Size in bytes of the scratch buffer (wdata) that ggml_cuda_mul_mat needs.
  //
  // ggml_cuda_mul_mat hands wdata only to ggml_cuda_mul_mat_f16, which it calls only
  // for an f16 src0 when ggml_cuda_mul_mat_use_f16 is true. In that case the buffer is
  // sized for src1 as ggml_fp16_t (presumably holding src1 converted to fp16 on the
  // CPU). Every other path ignores wdata, so 0 is returned.
  size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      // f32 and quantized src0 never reach the fp16 path. Note that
      // ggml_cuda_mul_mat_use_f16 is true for any non-empty f32 src1, so without this
      // check a useless buffer would be requested for those types.
      if (src0->type != GGML_TYPE_F16) {
          return 0;
      }
      if (!ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
          return 0;
      }
      return ggml_nelements(src1) * sizeof(ggml_fp16_t);
  }
  // Copy a tensor's data into a buffer from ggml_cuda_pool_malloc and mark the
  // tensor as GGML_BACKEND_CUDA.
  //
  // The tensor's 2D (ne0 x ne1) slices are stored back to back: slice (i2, i3) starts
  // at byte offset (i3*ne2 + i2) * slice_sz, where slice_sz is the byte size of one
  // slice for this tensor's type. This presumably matches the per-slice byte stride
  // used when reading GGML_BACKEND_CUDA weights in the quantized mat-mul path
  // (`(char *) src0->data + i * q_sz`). TODO confirm q_sz there is the per-slice byte size.
  //
  // On return all copies have completed, tensor->data points at the new buffer, and
  // the previous data pointer is overwritten, not freed.
  void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
      const int64_t ne0 = tensor->ne[0];
      const int64_t ne1 = tensor->ne[1];
      const int64_t ne2 = tensor->ne[2];
      const int64_t ne3 = tensor->ne[3];

      const ggml_type type = tensor->type;
      // Bytes in one ne0 x ne1 slice. Block types pack ggml_blck_size(type) elements
      // into ggml_type_size(type) bytes.
      const size_t slice_sz = ggml_type_size(type) * ne0 * ne1 / ggml_blck_size(type);
      const size_t q_sz     = slice_sz * ne2 * ne3;

      size_t q_size; // size actually handed out by the pool (not needed further here)
      char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);

      cudaStream_t cudaStream2 = g_cudaStreams2[0];

      // copy tensor to device, one 2D slice at a time
      for (int64_t i3 = 0; i3 < ne3; i3++) {
          for (int64_t i2 = 0; i2 < ne2; i2++) {
              const int64_t i = i3*ne2 + i2;
              // dst is a byte pointer, so the offset must be in bytes. Scaling by the
              // element count ne0*ne1 would misplace every slice after the first for
              // f16/quantized types and could run past the q_sz-byte allocation.
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*slice_sz, tensor, i3, i2, cudaStream2));
          }
      }
      // The copies were enqueued on cudaStream2. Finish them before the tensor is
      // published as device-resident and consumed from other streams.
      CUDA_CHECK(cudaStreamSynchronize(cudaStream2));

      tensor->data = dst;
      tensor->backend = GGML_BACKEND_CUDA;
  }
  // Load a tensor's data from a file straight into a new cudaMalloc'ed buffer.
  //
  // Reads ggml_nbytes(tensor) bytes from `fname` starting at byte `offset` into a
  // temporary host buffer. It then copies them to a device buffer, waits for the
  // device, and stores the device pointer in tensor->data. The temporary host buffer
  // is freed; the device buffer is not freed here.
  //
  // On failure (file cannot be opened, host allocation fails, seek fails, short read,
  // or a CUDA error) the process is terminated.
  void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
      FILE * fp = fopen(fname, "rb");
      if (fp == nullptr) {
          fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
          exit(1);
      }

      const size_t size = ggml_nbytes(tensor);

      void * buf_host = malloc(size);
      // malloc(0) may legitimately return NULL, so only treat NULL as an error for size > 0
      if (buf_host == nullptr && size > 0) {
          fprintf(stderr, "%s: failed to allocate host staging buffer\n", __func__);
          exit(1);
      }

  #ifdef _WIN32
      int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
  #else
      // NOTE(review): offsets beyond LONG_MAX are truncated where long is 32-bit
      int ret = fseek(fp, (long) offset, SEEK_SET);
  #endif
      GGML_ASSERT(ret == 0);

      // Read `size` one-byte items (rather than one `size`-byte item) so that a
      // zero-byte tensor is not reported as a short read.
      const size_t n_read = fread(buf_host, 1, size, fp);
      fclose(fp);
      if (n_read != size) {
          fprintf(stderr, "unexpectedly reached end of file\n");
          exit(1);
      }

      void * buf;
      CUDA_CHECK(cudaMalloc(&buf, size));
      CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
      // cudaMemcpy from pageable memory may return before the DMA has finished; wait for it
      CUDA_CHECK(cudaDeviceSynchronize());

      tensor->data = buf;
      free(buf_host);
  }