ggml-cuda.cu 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448
  1. #include <cstddef>
  2. #include <cstdint>
  3. #include <stdint.h>
  4. #include <stdio.h>
  5. #include <atomic>
  6. #include <assert.h>
  7. #include <cuda_runtime.h>
  8. #include <cublas_v2.h>
  9. #include <cuda_fp16.h>
  10. #include "ggml-cuda.h"
  11. #include "ggml.h"
  12. static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
  13. #define CUDA_CHECK(err) \
  14. do { \
  15. cudaError_t err_ = (err); \
  16. if (err_ != cudaSuccess) { \
  17. fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
  18. cudaGetErrorString(err_)); \
  19. exit(1); \
  20. } \
  21. } while (0)
  22. #define CUBLAS_CHECK(err) \
  23. do { \
  24. cublasStatus_t err_ = (err); \
  25. if (err_ != CUBLAS_STATUS_SUCCESS) { \
  26. fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
  27. exit(1); \
  28. } \
  29. } while (0)
  30. typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
  31. typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
  32. typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
  33. typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
  34. // QK = number of values after dequantization
  35. // QR = QK / number of values before dequantization
  36. #define QK4_0 32
  37. #define QR4_0 2
  38. typedef struct {
  39. half d; // delta
  40. uint8_t qs[QK4_0 / 2]; // nibbles / quants
  41. } block_q4_0;
  42. static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
  43. #define QK4_1 32
  44. #define QR4_1 2
  45. typedef struct {
  46. half d; // delta
  47. half m; // min
  48. uint8_t qs[QK4_1 / 2]; // nibbles / quants
  49. } block_q4_1;
  50. static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
  51. #define QK5_0 32
  52. #define QR5_0 2
  53. typedef struct {
  54. half d; // delta
  55. uint8_t qh[4]; // 5-th bit of quants
  56. uint8_t qs[QK5_0 / 2]; // nibbles / quants
  57. } block_q5_0;
  58. static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
  59. #define QK5_1 32
  60. #define QR5_1 2
  61. typedef struct {
  62. half d; // delta
  63. half m; // min
  64. uint8_t qh[4]; // 5-th bit of quants
  65. uint8_t qs[QK5_1 / 2]; // nibbles / quants
  66. } block_q5_1;
  67. static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
  68. #define QK8_0 32
  69. #define QR8_0 1
  70. typedef struct {
  71. half d; // delta
  72. int8_t qs[QK8_0]; // quants
  73. } block_q8_0;
  74. static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
  75. //================================= k-quants
  76. #define QK_K 256
  77. typedef struct {
  78. uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
  79. uint8_t qs[QK_K/4]; // quants
  80. half d; // super-block scale for quantized scales
  81. half dmin; // super-block scale for quantized mins
  82. } block_q2_k;
  83. static_assert(sizeof(block_q2_k) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_k block size/padding");
  84. typedef struct {
  85. uint8_t hmask[QK_K/8];
  86. uint8_t qs[QK_K/4]; // nibbles / quants
  87. uint8_t scales[3*QK_K/64];
  88. half d;
  89. } block_q3_k;
  90. static_assert(sizeof(block_q3_k) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_k block size/padding");
  91. typedef struct {
  92. half d; // super-block scale for quantized scales
  93. half dmin; // super-block scale for quantized mins
  94. uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
  95. uint8_t qs[QK_K/2]; // 4--bit quants
  96. } block_q4_k;
  97. static_assert(sizeof(block_q4_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_k block size/padding");
  98. typedef struct {
  99. half d; // super-block scale for quantized scales
  100. half dmin; // super-block scale for quantized mins
  101. uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
  102. uint8_t qh[QK_K/8]; // quants, high bit
  103. uint8_t qs[QK_K/2]; // quants, low 4 bits
  104. } block_q5_k;
  105. static_assert(sizeof(block_q5_k) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_k block size/padding");
  106. typedef struct {
  107. uint8_t ql[QK_K/2]; // quants, lower 4 bits
  108. uint8_t qh[QK_K/4]; // quants, upper 2 bits
  109. int8_t scales[QK_K/16]; // scales
  110. half d; // delta
  111. } block_q6_k;
  112. static_assert(sizeof(block_q6_k) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_k block size/padding");
  113. #define WARP_SIZE 32
  114. #define CUDA_MUL_BLOCK_SIZE 256
  115. #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
  116. // dmmv = dequantize_mul_mat_vec
  117. #ifndef GGML_CUDA_DMMV_X
  118. #define GGML_CUDA_DMMV_X 32
  119. #endif
  120. #ifndef GGML_CUDA_DMMV_Y
  121. #define GGML_CUDA_DMMV_Y 1
  122. #endif
  123. static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
  124. const int i = blockDim.x*blockIdx.x + threadIdx.x;
  125. if (i >= kx) {
  126. return;
  127. }
  128. dst[i] = x[i] * y[i%ky];
  129. }
  130. static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  131. const block_q4_0 * x = (const block_q4_0 *) vx;
  132. const float d = x[ib].d;
  133. const uint8_t vui = x[ib].qs[iqs];
  134. const int8_t vi0 = vui & 0xF;
  135. const int8_t vi1 = vui >> 4;
  136. v0 = (vi0 - 8)*d;
  137. v1 = (vi1 - 8)*d;
  138. }
  139. static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  140. const block_q4_1 * x = (const block_q4_1 *) vx;
  141. const float d = x[ib].d;
  142. const float m = x[ib].m;
  143. const uint8_t vui = x[ib].qs[iqs];
  144. const int8_t vi0 = vui & 0xF;
  145. const int8_t vi1 = vui >> 4;
  146. v0 = vi0*d + m;
  147. v1 = vi1*d + m;
  148. }
  149. static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  150. const block_q5_0 * x = (const block_q5_0 *) vx;
  151. const float d = x[ib].d;
  152. uint32_t qh;
  153. memcpy(&qh, x[ib].qh, sizeof(qh));
  154. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  155. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  156. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
  157. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
  158. v0 = x0*d;
  159. v1 = x1*d;
  160. }
  161. static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  162. const block_q5_1 * x = (const block_q5_1 *) vx;
  163. const float d = x[ib].d;
  164. const float m = x[ib].m;
  165. uint32_t qh;
  166. memcpy(&qh, x[ib].qh, sizeof(qh));
  167. const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  168. const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  169. const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
  170. const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
  171. v0 = x0*d + m;
  172. v1 = x1*d + m;
  173. }
  174. static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  175. const block_q8_0 * x = (const block_q8_0 *) vx;
  176. const float d = x[ib].d;
  177. const int8_t vi0 = x[ib].qs[iqs + 0];
  178. const int8_t vi1 = x[ib].qs[iqs + 1];
  179. v0 = vi0*d;
  180. v1 = vi1*d;
  181. }
  182. //================================== k-quants
  183. static __global__ void dequantize_block_q2_k(const void * vx, float * yy) {
  184. const int i = blockIdx.x;
  185. const int tid = threadIdx.x;
  186. const int n = tid/32;
  187. const int l = tid - 32*n;
  188. const int is = 8*n + l/16;
  189. const block_q2_k * x = (const block_q2_k *) vx;
  190. const uint8_t q = x[i].qs[32*n + l];
  191. float * y = yy + i*QK_K + 128*n;
  192. float dall = x[i].d;
  193. float dmin = x[i].dmin;
  194. y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  195. y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
  196. y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
  197. y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
  198. }
  199. static __device__ void vec_dot_q2_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
  200. const block_q2_k * x = (const block_q2_k *) vx;
  201. // if n is 0, we want to do the lower 128, else the upper 128,
  202. // covering y[l+0], y[l+32], y[l+64], y[l+96] and
  203. // y[l+16], y[l+48], y[l+80], y[l+112]
  204. int n = iqs/128; // 0 or 1
  205. int r = iqs - 128*n; // 0...120 in steps of 8
  206. int l = r/8; // 0...15 in steps of 1
  207. const float * y = yy + 128*n + l;
  208. const uint8_t * q = x[ib].qs + 32*n + l;
  209. const uint8_t * s = x[ib].scales + 8*n;
  210. const float dall = x[ib].d;
  211. const float dmin = x[ib].dmin;
  212. float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
  213. + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
  214. + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
  215. + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
  216. + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
  217. + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
  218. + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
  219. + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
  220. result = sum;
  221. }
  222. static __global__ void dequantize_block_q3_k(const void * vx, float * yy) {
  223. int r = threadIdx.x/4;
  224. int i = blockIdx.x;
  225. int tid = r/2;
  226. int is0 = r%2;
  227. int l0 = 16*is0 + 4*(threadIdx.x%4);
  228. int n = tid / 4;
  229. int j = tid - 4*n;
  230. const block_q3_k * x = (const block_q3_k *) vx;
  231. uint8_t m = 1 << (4*n + j);
  232. int is = 8*n + 2*j + is0;
  233. int shift = 2*j;
  234. int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
  235. is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
  236. is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
  237. (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
  238. float d_all = x[i].d;
  239. float dl = d_all * (us - 32);
  240. float * y = yy + i*QK_K + 128*n + 32*j;
  241. const uint8_t * q = x[i].qs + 32*n;
  242. const uint8_t * hm = x[i].hmask;
  243. for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
  244. }
  245. static __device__ void vec_dot_q3_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
  246. const block_q3_k * x = (const block_q3_k *) vx;
  247. const uint32_t kmask1 = 0x03030303;
  248. const uint32_t kmask2 = 0x0f0f0f0f;
  249. uint32_t aux[3];
  250. uint32_t utmp[4];
  251. // if n is 0, we want to do the lower 128, else the upper 128,
  252. // covering y[l+0], y[l+32], y[l+64], y[l+96] and
  253. // y[l+16], y[l+48], y[l+80], y[l+112]
  254. int n = iqs/128; // 0 or 1
  255. int r = iqs - 128*n; // 0...120 in steps of 8
  256. int l = r/8; // 0...15 in steps of 1
  257. const float * y = yy + 128*n + l;
  258. const uint8_t * q = x[ib].qs + 32*n + l;
  259. const uint8_t * hm = x[ib].hmask + l;
  260. const int8_t * s = (const int8_t *)utmp + 8*n;
  261. memcpy(aux, x[ib].scales, 12);
  262. utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
  263. utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
  264. utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
  265. utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
  266. const float dall = x[ib].d;
  267. const uint8_t m = 1 << (4*n);
  268. float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
  269. + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
  270. + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
  271. + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
  272. + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
  273. + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
  274. + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
  275. + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
  276. result = sum * dall;
  277. }
  278. static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
  279. if (j < 4) {
  280. d = q[j] & 63; m = q[j + 4] & 63;
  281. } else {
  282. d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
  283. m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  284. }
  285. }
  286. static __global__ void dequantize_block_q4_k(const void * vx, float * yy) {
  287. const block_q4_k * x = (const block_q4_k *) vx;
  288. const int i = blockIdx.x;
  289. //// assume 64 threads - this is very slightly better than the one below
  290. //const int tid = threadIdx.x;
  291. //const int il = tid/16;
  292. //const int ir = tid%16;
  293. //const int is = 2*il;
  294. //const int n = 2;
  295. // assume 32 threads
  296. const int tid = threadIdx.x;
  297. const int il = tid/8;
  298. const int ir = tid%8;
  299. const int is = 2*il;
  300. const int n = 4;
  301. float * y = yy + i*QK_K + 64*il + n*ir;
  302. const float dall = x[i].d;
  303. const float dmin = x[i].dmin;
  304. const uint8_t * q = x[i].qs + 32*il + n*ir;
  305. uint8_t sc, m;
  306. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  307. const float d1 = dall * sc; const float m1 = dmin * m;
  308. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  309. const float d2 = dall * sc; const float m2 = dmin * m;
  310. for (int l = 0; l < n; ++l) {
  311. y[l + 0] = d1 * (q[l] & 0xF) - m1;
  312. y[l +32] = d2 * (q[l] >> 4) - m2;
  313. }
  314. }
  315. static __device__ void vec_dot_q4_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
  316. const block_q4_k * x = (const block_q4_k *) vx;
  317. // iqs is in 0...248 in steps of 8 =>
  318. const int j = iqs / 64; // j is in 0...3
  319. const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
  320. const int is = 2*j; // is is in 0...6 in steps of 2
  321. const float * y = yy + 64*j + ir;
  322. const uint8_t * q = x[ib].qs + 32*j + ir;
  323. const float dall = x[ib].d;
  324. const float dmin = x[ib].dmin;
  325. uint8_t sc, m;
  326. get_scale_min_k4(is + 0, x[ib].scales, sc, m);
  327. const float d1 = dall * sc;
  328. const float m1 = dmin * m;
  329. get_scale_min_k4(is + 1, x[ib].scales, sc, m);
  330. const float d2 = dall * sc;
  331. const float m2 = dmin * m;
  332. float sum = 0;
  333. for (int k = 0; k < 4; ++k) {
  334. sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
  335. sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
  336. }
  337. result = sum;
  338. }
  339. static __global__ void dequantize_block_q5_k(const void * vx, float * yy) {
  340. const block_q5_k * x = (const block_q5_k *) vx;
  341. const int i = blockIdx.x;
  342. // assume 64 threads - this is very slightly better than the one below
  343. const int tid = threadIdx.x;
  344. const int il = tid/16; // il is in 0...3
  345. const int ir = tid%16; // ir is in 0...15
  346. const int is = 2*il; // is is in 0...6
  347. float * y = yy + i*QK_K + 64*il + 2*ir;
  348. const float dall = x[i].d;
  349. const float dmin = x[i].dmin;
  350. const uint8_t * ql = x[i].qs + 32*il + 2*ir;
  351. const uint8_t * qh = x[i].qh + 2*ir;
  352. uint8_t sc, m;
  353. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  354. const float d1 = dall * sc; const float m1 = dmin * m;
  355. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  356. const float d2 = dall * sc; const float m2 = dmin * m;
  357. uint8_t hm = 1 << (2*il);
  358. y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
  359. y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
  360. hm <<= 1;
  361. y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
  362. y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
  363. }
  364. static __device__ void vec_dot_q5_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
  365. const block_q5_k * x = (const block_q5_k *) vx;
  366. // iqs is in 0...248 in steps of 8 =>
  367. const int j = iqs / 64; // j is in 0...3
  368. const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
  369. const int is = 2*j; // is is in 0...6 in steps of 2
  370. const float * y = yy + 64*j + ir;
  371. const uint8_t * ql = x[ib].qs + 32*j + ir;
  372. const uint8_t * qh = x[ib].qh + ir;
  373. const float dall = x[ib].d;
  374. const float dmin = x[ib].dmin;
  375. uint8_t sc, m;
  376. get_scale_min_k4(is + 0, x[ib].scales, sc, m);
  377. const float d1 = dall * sc;
  378. const float m1 = dmin * m;
  379. get_scale_min_k4(is + 1, x[ib].scales, sc, m);
  380. const float d2 = dall * sc;
  381. const float m2 = dmin * m;
  382. uint8_t hm = 1 << is;
  383. float sum = 0;
  384. for (int k = 0; k < 4; ++k) {
  385. sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
  386. }
  387. hm <<= 1;
  388. for (int k = 0; k < 4; ++k) {
  389. sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
  390. }
  391. result = sum;
  392. }
  393. static __global__ void dequantize_block_q6_k(const void * vx, float * yy) {
  394. const block_q6_k * x = (const block_q6_k *) vx;
  395. const int i = blockIdx.x;
  396. // assume 64 threads - this is very slightly better than the one below
  397. const int tid = threadIdx.x;
  398. const int ip = tid/32; // ip is 0 or 1
  399. const int il = tid - 32*ip; // 0...32
  400. const int is = 8*ip + il/16;
  401. float * y = yy + i*QK_K + 128*ip + il;
  402. const float d = x[i].d;
  403. const uint8_t * ql = x[i].ql + 64*ip + il;
  404. const uint8_t qh = x[i].qh[32*ip + il];
  405. const int8_t * sc = x[i].scales + is;
  406. y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  407. y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
  408. y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  409. y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
  410. }
  411. static __device__ void vec_dot_q6_k(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
  412. const block_q6_k * x = (const block_q6_k *) vx;
  413. const int ip = iqs / 128; // 0 or 1
  414. const int il = (iqs - 128*ip)/8; // 0...15
  415. const int is = 8*ip;
  416. const float * y = yy + 128*ip + il;
  417. const float d = x[ib].d;
  418. const uint8_t * ql = x[ib].ql + 64*ip + il;
  419. const uint8_t * qh = x[ib].qh + 32*ip + il;
  420. const int8_t * sc = x[ib].scales + is;
  421. result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
  422. + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
  423. + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
  424. + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
  425. + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
  426. + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
  427. + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
  428. + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
  429. }
  430. static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
  431. const half * x = (const half *) vx;
  432. v0 = __half2float(x[ib + 0]);
  433. v1 = __half2float(x[ib + 1]);
  434. }
  435. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  436. static __global__ void dequantize_block(const void * vx, float * y, const int k) {
  437. const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
  438. if (i >= k) {
  439. return;
  440. }
  441. const int ib = i/qk; // block index
  442. const int iqs = (i%qk)/qr; // quant index
  443. const int iybs = i - i%qk; // y block start index
  444. const int y_offset = qr == 1 ? 1 : qk/2;
  445. // dequantize
  446. float & v0 = y[iybs + iqs + 0];
  447. float & v1 = y[iybs + iqs + y_offset];
  448. dequantize_kernel(vx, ib, iqs, v0, v1);
  449. }
  450. template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  451. static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
  452. // qk = quantized weights per x block
  453. // qr = number of quantized weights per data value in x block
  454. const int row = blockIdx.x*blockDim.y + threadIdx.y;
  455. const int tid = threadIdx.x;
  456. const int iter_stride = 2*GGML_CUDA_DMMV_X;
  457. const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
  458. const int y_offset = qr == 1 ? 1 : qk/2;
  459. float tmp = 0; // partial sum for thread in warp
  460. for (int i = 0; i < ncols; i += iter_stride) {
  461. const int col = i + vals_per_iter*tid;
  462. const int ib = (row*ncols + col)/qk; // x block index
  463. const int iqs = (col%qk)/qr; // x quant index
  464. const int iybs = col - col%qk; // y block start index
  465. // processing >2 values per i iter is faster for fast GPUs
  466. #pragma unroll
  467. for (int j = 0; j < vals_per_iter; j += 2) {
  468. // process 2 vals per j iter
  469. // dequantize
  470. float v0, v1;
  471. dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
  472. // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
  473. // matrix multiplication
  474. tmp += v0 * y[iybs + iqs + j/qr + 0];
  475. tmp += v1 * y[iybs + iqs + j/qr + y_offset];
  476. // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
  477. }
  478. }
  479. // sum up partial sums and write back result
  480. __syncthreads();
  481. #pragma unroll
  482. for (int mask = 16; mask > 0; mask >>= 1) {
  483. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  484. }
  485. if (tid == 0) {
  486. dst[row] = tmp;
  487. }
  488. }
  489. template <int n_thread, dot_kernel_k_t dot_kernel>
  490. static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols) {
  491. const int row = blockIdx.x*blockDim.y + threadIdx.y;
  492. const int tid = threadIdx.x;
  493. const int iter_stride = QK_K;
  494. const int vals_per_iter = iter_stride / n_thread;
  495. const int num_blocks_per_row = ncols / QK_K;
  496. const int ib0 = row*num_blocks_per_row;
  497. float tmp = 0; // partial sum for thread in warp
  498. for (int i = 0; i < ncols; i += iter_stride) {
  499. const int col = i + vals_per_iter*tid;
  500. const int ib = ib0 + col/QK_K; // x block index
  501. const int iqs = col%QK_K; // x quant index
  502. const int iybs = col - col%QK_K; // y block start index
  503. float v;
  504. dot_kernel(vx, ib, iqs, y + iybs, v);
  505. tmp += v;
  506. }
  507. // sum up partial sums and write back result
  508. __syncthreads();
  509. #pragma unroll
  510. for (int mask = 16; mask > 0; mask >>= 1) {
  511. tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
  512. }
  513. if (tid == 0) {
  514. dst[row] = tmp;
  515. }
  516. }
  517. static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
  518. const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
  519. mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
  520. }
  521. static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  522. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  523. dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  524. }
  525. static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  526. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  527. dequantize_block<QK4_1, QR4_1, dequantize_q4_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  528. }
  529. static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  530. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  531. dequantize_block<QK5_0, QR5_0, dequantize_q5_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  532. }
  533. static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  534. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  535. dequantize_block<QK5_1, QR5_1, dequantize_q5_1><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  536. }
  537. static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  538. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  539. dequantize_block<QK8_0, QR8_0, dequantize_q8_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  540. }
  541. static void dequantize_row_q2_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  542. const int nb = k / QK_K;
  543. dequantize_block_q2_k<<<nb, 64, 0, stream>>>(vx, y);
  544. }
  545. static void dequantize_row_q3_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  546. const int nb = k / QK_K;
  547. dequantize_block_q3_k<<<nb, 64, 0, stream>>>(vx, y);
  548. }
  549. static void dequantize_row_q4_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  550. const int nb = k / QK_K;
  551. dequantize_block_q4_k<<<nb, 32, 0, stream>>>(vx, y);
  552. }
  553. static void dequantize_row_q5_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  554. const int nb = k / QK_K;
  555. dequantize_block_q5_k<<<nb, 64, 0, stream>>>(vx, y);
  556. }
  557. static void dequantize_row_q6_k_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
  558. const int nb = k / QK_K;
  559. dequantize_block_q6_k<<<nb, 64, 0, stream>>>(vx, y);
  560. }
  // Launchers for the fused dequantize-and-multiply kernel `dequantize_mul_mat_vec`
  // applied to an encoded matrix of `nrows` x `ncols` at `vx` and the vector `y`.
  // Each launch uses blocks of WARP_SIZE x GGML_CUDA_DMMV_Y threads and
  // nrows / GGML_CUDA_DMMV_Y blocks, so both dimensions must be multiples of the
  // DMMV tile sizes (asserted). Work is enqueued on `stream` without synchronizing.

  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }
  // Launchers for the k-quant matrix-vector kernel `dequantize_mul_mat_vec_k`.
  // Each CUDA block has 32 x ny threads (ny = 2), and the grid has nrows / ny blocks.
  //
  // The kernel is called with (vx, y, dst, ncols) only. It never receives nrows,
  // so it cannot skip rows past the end of the matrix. nrows must therefore be an
  // exact multiple of ny:
  //   - rounding the grid up would let the extra rows write past `dst`;
  //   - rounding it down would silently leave the last row uncomputed.
  // Both preconditions (ncols % QK_K, nrows % ny) are asserted, mirroring the
  // GGML_CUDA_DMMV_Y check of the non-k launchers. Work is enqueued on `stream`.

  static void dequantize_mul_mat_vec_q2_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);
      const int ny = 2;
      GGML_ASSERT(nrows % ny == 0);
      const dim3 block_dims(32, ny, 1);
      dequantize_mul_mat_vec_k<32, vec_dot_q2_k><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q3_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);
      const int ny = 2;
      GGML_ASSERT(nrows % ny == 0);
      const dim3 block_dims(32, ny, 1);
      dequantize_mul_mat_vec_k<32, vec_dot_q3_k><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q4_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);
      const int ny = 2;
      GGML_ASSERT(nrows % ny == 0);
      const dim3 block_dims(32, ny, 1);
      dequantize_mul_mat_vec_k<32, vec_dot_q4_k><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q5_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);
      const int ny = 2;
      GGML_ASSERT(nrows % ny == 0);
      const dim3 block_dims(32, ny, 1);
      dequantize_mul_mat_vec_k<32, vec_dot_q5_k><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }

  static void dequantize_mul_mat_vec_q6_k_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);
      const int ny = 2;
      GGML_ASSERT(nrows % ny == 0);
      const dim3 block_dims(32, ny, 1);
      dequantize_mul_mat_vec_k<32, vec_dot_q6_k><<<nrows/ny, block_dims, 0, stream>>>(vx, y, dst, ncols);
  }
  // Converts `k` fp16 values at `vx` into floats at `y`; enqueued on `stream`.
  static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
      const int grid = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
      dequantize_block<32, 1, convert_f16><<<grid, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  }

  // Matrix-vector product with an fp16 matrix. It reuses the generic
  // dequantize_mul_mat_vec kernel with template arguments <1, 1> and the fp16
  // converter. Same launch shape and preconditions as the quantized DMMV launchers.
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
      GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
      const int  n_row_blocks = nrows / GGML_CUDA_DMMV_Y;
      const dim3 threads(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
      dequantize_mul_mat_vec<1, 1, convert_f16><<<n_row_blocks, threads, 0, stream>>>(vx, y, dst, ncols);
  }
  // Maps a tensor type to the host launcher that converts it to f32 on the device.
  // Returns nullptr for types that have no CUDA conversion.
  static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
      to_fp32_cuda_t launcher = nullptr;
      switch (type) {
          case GGML_TYPE_Q4_0: launcher = dequantize_row_q4_0_cuda;  break;
          case GGML_TYPE_Q4_1: launcher = dequantize_row_q4_1_cuda;  break;
          case GGML_TYPE_Q5_0: launcher = dequantize_row_q5_0_cuda;  break;
          case GGML_TYPE_Q5_1: launcher = dequantize_row_q5_1_cuda;  break;
          case GGML_TYPE_Q8_0: launcher = dequantize_row_q8_0_cuda;  break;
          case GGML_TYPE_Q2_K: launcher = dequantize_row_q2_k_cuda;  break;
          case GGML_TYPE_Q3_K: launcher = dequantize_row_q3_k_cuda;  break;
          case GGML_TYPE_Q4_K: launcher = dequantize_row_q4_k_cuda;  break;
          case GGML_TYPE_Q5_K: launcher = dequantize_row_q5_k_cuda;  break;
          case GGML_TYPE_Q6_K: launcher = dequantize_row_q6_k_cuda;  break;
          case GGML_TYPE_F16:  launcher = convert_fp16_to_fp32_cuda; break;
          default:             break; // unsupported type: stays nullptr
      }
      return launcher;
  }
  // Maps a tensor type to its fused dequantize + matrix-vector launcher.
  // Returns nullptr for types without one.
  static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
      dequantize_mul_mat_vec_cuda_t launcher = nullptr;
      switch (type) {
          case GGML_TYPE_Q4_0: launcher = dequantize_mul_mat_vec_q4_0_cuda; break;
          case GGML_TYPE_Q4_1: launcher = dequantize_mul_mat_vec_q4_1_cuda; break;
          case GGML_TYPE_Q5_0: launcher = dequantize_mul_mat_vec_q5_0_cuda; break;
          case GGML_TYPE_Q5_1: launcher = dequantize_mul_mat_vec_q5_1_cuda; break;
          case GGML_TYPE_Q8_0: launcher = dequantize_mul_mat_vec_q8_0_cuda; break;
          case GGML_TYPE_Q2_K: launcher = dequantize_mul_mat_vec_q2_k_cuda; break;
          case GGML_TYPE_Q3_K: launcher = dequantize_mul_mat_vec_q3_k_cuda; break;
          case GGML_TYPE_Q4_K: launcher = dequantize_mul_mat_vec_q4_k_cuda; break;
          case GGML_TYPE_Q5_K: launcher = dequantize_mul_mat_vec_q5_k_cuda; break;
          case GGML_TYPE_Q6_K: launcher = dequantize_mul_mat_vec_q6_k_cuda; break;
          case GGML_TYPE_F16:  launcher = convert_mul_mat_vec_f16_cuda;     break;
          default:             break; // unsupported type: stays nullptr
      }
      return launcher;
  }
  // buffer pool for cuda
  #define MAX_CUDA_BUFFERS 256

  // RAII guard around an std::atomic_flag.
  // The constructor busy-waits until it sets the flag (acquire ordering).
  // The destructor clears the flag (release ordering). Copying is disabled.
  struct scoped_spin_lock {
      std::atomic_flag& lock;

      scoped_spin_lock(std::atomic_flag& flag) : lock(flag) {
          // keep retrying until test_and_set observes the flag clear
          while (lock.test_and_set(std::memory_order_acquire)) {
          }
      }

      ~scoped_spin_lock() {
          lock.clear(std::memory_order_release);
      }

      scoped_spin_lock(const scoped_spin_lock&) = delete;
      scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
  };

  // One pool slot: a cached allocation and its capacity in bytes.
  // ptr == nullptr means the slot is empty.
  struct cuda_buffer {
      void * ptr  {nullptr};
      size_t size {0};
  };

  // Fixed-size pool of cached allocations, guarded by g_cuda_pool_lock.
  static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
  static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
  // Returns a device buffer of at least `size` bytes.
  // It reuses the first cached pool entry whose capacity is large enough,
  // removing it from the pool. Otherwise it cudaMallocs exactly `size` bytes.
  // The capacity actually obtained is stored in *actual_size; callers in this
  // file hand that value back to ggml_cuda_pool_free.
  // The whole search and any cudaMalloc run while holding g_cuda_pool_lock.
  static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
      scoped_spin_lock guard(g_cuda_pool_lock);

      for (cuda_buffer & slot : g_cuda_buffer_pool) {
          if (slot.ptr == nullptr || slot.size < size) {
              continue; // empty slot or too small
          }
          void * reused = slot.ptr;
          *actual_size = slot.size;
          slot.ptr  = nullptr;
          slot.size = 0;
          return reused;
      }

      // nothing suitable cached: allocate a fresh buffer of the exact size
      void * fresh = nullptr;
      CUDA_CHECK(cudaMalloc((void **) &fresh, size));
      *actual_size = size;
      return fresh;
  }
  // Returns a device buffer of capacity `size` to the pool by caching it in the
  // first empty slot. If every slot is occupied, it prints a warning and releases
  // the buffer with cudaFree instead. Runs while holding g_cuda_pool_lock.
  static void ggml_cuda_pool_free(void * ptr, size_t size) {
      scoped_spin_lock guard(g_cuda_pool_lock);

      for (cuda_buffer & slot : g_cuda_buffer_pool) {
          if (slot.ptr != nullptr) {
              continue; // occupied
          }
          slot.ptr  = ptr;
          slot.size = size;
          return;
      }

      fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
      CUDA_CHECK(cudaFree(ptr));
  }
  #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
  #define GGML_CUDA_MAX_EVENTS 64

  // Global cuBLAS handle plus pools of streams and events, filled in by ggml_init_cublas.
  // The matmul code in this file picks a stream pair by (slice index % GGML_CUDA_MAX_STREAMS)
  // and an event by (slice index % GGML_CUDA_MAX_EVENTS).
  static cublasHandle_t g_cublasH = nullptr;
  static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = {};   // primary (compute) streams
  static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = {};  // secondary (upload/convert) streams
  static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = {};
  // One-time creation of the global CUDA streams, events and cuBLAS handle.
  // Calls after the handle exists return immediately.
  // NOTE(review): the g_cublasH check is unsynchronized; presumably this is only
  // ever called from one thread.
  void ggml_init_cublas() {
      if (g_cublasH != nullptr) {
          return; // already initialized
      }

      for (int s = 0; s < GGML_CUDA_MAX_STREAMS; ++s) {
          CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[s],  cudaStreamNonBlocking));
          CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[s], cudaStreamNonBlocking));
      }

      for (int e = 0; e < GGML_CUDA_MAX_EVENTS; ++e) {
          // in this file events are only used to order streams, so timing is disabled
          CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[e], cudaEventDisableTiming));
      }

      CUBLAS_CHECK(cublasCreate(&g_cublasH));
      // allow TF32 tensor-op math for FP32 GEMMs
      CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));

      // to send cuBLAS logging to stdout:
      // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
  }
  // Allocates `size` bytes of pinned (page-locked) host memory with cudaMallocHost.
  // Returns nullptr instead of aborting when the GGML_CUDA_NO_PINNED environment
  // variable is set, or when the allocation fails (after printing a warning).
  void * ggml_cuda_host_malloc(size_t size) {
      if (getenv("GGML_CUDA_NO_PINNED") != nullptr) {
          return nullptr;
      }

      void * ptr = nullptr;
      cudaError_t err = cudaMallocHost((void **) &ptr, size);
      if (err != cudaSuccess) {
          // The failure is tolerated here, but the CUDA runtime also records it as
          // the "last error". Clear it now. Otherwise the next
          // CUDA_CHECK(cudaGetLastError()) after an unrelated kernel launch would
          // report this stale allocation failure.
          (void) cudaGetLastError();
          fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n",
              size/1024.0/1024.0, cudaGetErrorString(err));
          return nullptr;
      }
      return ptr;
  }

  // Releases memory obtained from ggml_cuda_host_malloc; the result is checked with CUDA_CHECK.
  void ggml_cuda_host_free(void * ptr) {
      CUDA_CHECK(cudaFreeHost(ptr));
  }
  // Enqueues an async host->device copy of the 2D slice (i2, i3) of `src` into `dst`.
  // The destination is packed densely: row i1 starts at byte i1*ts*ne0/bs.
  // Three copy strategies are used:
  //   - one memcpy when the slice is fully contiguous;
  //   - one pitched 2D copy when only the rows are contiguous;
  //   - a per-row pitched copy when elements are strided.
  // Returns the first failing CUDA status, or cudaSuccess.
  static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
      const uint64_t ne0 = src->ne[0];
      const uint64_t ne1 = src->ne[1];
      const uint64_t nb0 = src->nb[0];
      const uint64_t nb1 = src->nb[1];
      const uint64_t nb2 = src->nb[2];
      const uint64_t nb3 = src->nb[3];
      const size_t   ts  = ggml_type_size(src->type);
      const size_t   bs  = ggml_blck_size(src->type);

      const char * slice = (const char *) src->data + i2*nb2 + i3*nb3;

      if (nb0 != ts) {
          // strided elements: copy each row as a one-column "matrix" with source pitch nb0
          for (uint64_t i1 = 0; i1 < ne1; i1++) {
              const void * row_src = (const void *) (slice + i1*nb1);
              void * row_dst = (void *) ((char *) dst + i1*ts*ne0/bs);
              const cudaError_t status = cudaMemcpy2DAsync(row_dst, ts/bs, row_src, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
              if (status != cudaSuccess) {
                  return status;
              }
          }
          return cudaSuccess;
      }

      const uint64_t row_bytes = ts*ne0/bs; // size of one densely packed row
      if (nb1 == row_bytes) {
          // the whole slice is contiguous
          return cudaMemcpyAsync(dst, slice, ne1*nb1, cudaMemcpyHostToDevice, stream);
      }
      // rows are contiguous but separated by a pitch of nb1 bytes
      return cudaMemcpy2DAsync(dst, row_bytes, slice, nb1, row_bytes, ne1, cudaMemcpyHostToDevice, stream);
  }
  // Element-wise product dst = src0 * src1 for f32 tensors.
  //
  // src0 is read from src0->data and the results are written to dst->data via
  // host<->device copies. Work proceeds one 2D slice (i02, i03) at a time:
  //   1. the src0 slice is uploaded on a secondary stream;
  //   2. an event makes the primary stream wait for that upload;
  //   3. rows are multiplied with mul_f32_cuda on the primary stream;
  //   4. the slice is copied back into dst->data.
  //
  // src1 must already be on the device (asserted). Its row for (i01, i02, i03) is
  // chosen modulo its own dimensions (ne11, ne12, ne13), so src1 is broadcast
  // across src0. The function blocks on cudaDeviceSynchronize before returning.
  static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];
      const int64_t ne03 = src0->ne[3]; // outermost dimension
      const int64_t ne0 = ne00 * ne01 * ne02 * ne03; // total number of src0 (and dst) elements
      const int64_t ne10 = src1->ne[0];
      const int64_t ne11 = src1->ne[1];
      const int64_t ne12 = src1->ne[2];
      const int64_t ne13 = src1->ne[3];
      // byte strides as size_t so large tensors cannot overflow an int
      const size_t nb2 = dst->nb[2];
      const size_t nb3 = dst->nb[3];
      size_t x_size, d_size;
      float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
      float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
      float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
      for (int64_t i03 = 0; i03 < ne03; i03++) {
          const int64_t i13 = i03 % ne13;
          for (int64_t i02 = 0; i02 < ne02; i02++) {
              const int64_t i12 = i02 % ne12;
              const int64_t i0 = i03*ne02 + i02; // slice index
              float * c_X2 = d_X + i0*ne01*ne00;
              float * c_D2 = d_D + i0*ne01*ne00;
              cudaStream_t cudaStream  = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
              cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
              cudaEvent_t  cudaEvent   = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
              // copy src0 slice to device on the secondary stream
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
              CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
              // compute stream waits for the upload
              CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
              for (int64_t i01 = 0; i01 < ne01; i01++) {
                  const int64_t i11 = i01 % ne11;
                  const int64_t i1 = i13*ne12*ne11 + i12*ne11 + i11; // flattened src1 row
                  float * c_X1 = c_X2 + i01*ne00;
                  float * c_Y  = d_Y + i1*ne10;
                  float * c_D1 = c_D2 + i01*ne00;
                  mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
                  CUDA_CHECK(cudaGetLastError());
              }
              // copy dst slice to host
              float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
              CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
          }
      }
      CUDA_CHECK(cudaDeviceSynchronize());
      ggml_cuda_pool_free(d_X, x_size);
      ggml_cuda_pool_free(d_D, d_size);
  }
  // f32 x f32 matrix product through cuBLAS SGEMM, one 2D slice (i02, i03) at a time.
  //
  // For each slice, src0 and src1 are uploaded from host memory and multiplied
  // with cublasSgemm (src0 transposed, src1 as-is) into an ne01 x ne11 result,
  // which is copied back to dst->data. Slices are spread round-robin over
  // g_cudaStreams. The function blocks on cudaDeviceSynchronize before returning
  // its scratch buffers to the pool.
  static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];
      const int64_t ne03 = src0->ne[3];
      const int64_t ne10 = src1->ne[0];
      const int64_t ne11 = src1->ne[1];
      const int nb2 = dst->nb[2];
      const int nb3 = dst->nb[3];

      // element counts of one slice of src0 / src1 / dst, and the number of slices
      const int x_ne = ne01 * ne00;
      const int y_ne = ne11 * ne10;
      const int d_ne = ne11 * ne01;
      const int n_mm = ne03 * ne02;

      size_t x_size, y_size, d_size;
      float * dev_src0 = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
      float * dev_src1 = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
      float * dev_dst  = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);

      const float alpha = 1.0f;
      const float beta  = 0.0f;

      for (int64_t i03 = 0; i03 < ne03; i03++) {
          for (int64_t i02 = 0; i02 < ne02; i02++) {
              const int slice = i03*ne02 + i02;
              cudaStream_t stream = g_cudaStreams[slice % GGML_CUDA_MAX_STREAMS];

              float * slice_src0 = dev_src0 + slice * x_ne;
              float * slice_src1 = dev_src1 + slice * y_ne;
              float * slice_dst  = dev_dst  + slice * d_ne;

              // upload both operands on this slice's stream
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(slice_src0, src0, i03, i02, stream));
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(slice_src1, src1, i03, i02, stream));

              CUBLAS_CHECK(cublasSetStream(g_cublasH, stream));
              CUBLAS_CHECK(
                  cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                          ne01, ne11, ne10,
                          &alpha, slice_src0, ne00,
                                  slice_src1, ne10,
                          &beta,  slice_dst,  ne01));

              // download the result into the matching dst slice
              float * host_out = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
              CUDA_CHECK(cudaMemcpyAsync(host_out, slice_dst, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, stream));
          }
      }

      CUDA_CHECK(cudaDeviceSynchronize());
      ggml_cuda_pool_free(dev_src0, x_size);
      ggml_cuda_pool_free(dev_src1, y_size);
      ggml_cuda_pool_free(dev_dst,  d_size);
  }
  // f16 src0 x f32 src1 matrix product through cuBLAS GemmEx.
  // Inputs are CUDA_R_16F, the output is CUDA_R_32F, and the compute type is
  // CUBLAS_COMPUTE_32F_FAST_16F.
  //
  // For every 2D slice (i02, i03):
  //   1. src0 is uploaded unchanged;
  //   2. src1 is converted to fp16 on the CPU into this slice's region of `wdata`;
  //   3. the fp16 src1 is uploaded and both are multiplied;
  //   4. the ne01 x ne11 result is copied back to dst->data.
  // Each slice's region of wdata holds ne10*ne11 values, so wdata must hold
  // n_mm*ne10*ne11 ggml_fp16_t values. The function blocks on
  // cudaDeviceSynchronize before returning.
  static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];
      const int64_t ne03 = src0->ne[3];
      const int64_t ne10 = src1->ne[0];
      const int64_t ne11 = src1->ne[1];
      // byte strides as size_t so large tensors cannot overflow an int
      const size_t nb10 = src1->nb[0];
      const size_t nb11 = src1->nb[1];
      const size_t nb12 = src1->nb[2];
      const size_t nb13 = src1->nb[3];
      const size_t nb2 = dst->nb[2];
      const size_t nb3 = dst->nb[3];
      const float alpha = 1.0f;
      const float beta = 0.0f;
      const int x_ne = ne01 * ne00;
      const int y_ne = ne11 * ne10;
      const int d_ne = ne11 * ne01;
      const int n_mm = ne03 * ne02;
      size_t x_size, y_size, d_size;
      half  * d_X = (half *)  ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
      half  * d_Y = (half *)  ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
      float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
      // elements inside a src1 row are adjacent floats
      const bool src1_cont_rows = nb10 == sizeof(float);
      // consecutive src1 rows are adjacent: the row stride equals one row of ne10
      // floats. Comparing against ne11 (the row count) would misclassify padded
      // or non-square slices.
      const bool src1_cont_cols = nb11 == ne10*sizeof(float);
      for (int64_t i03 = 0; i03 < ne03; i03++) {
          for (int64_t i02 = 0; i02 < ne02; i02++) {
              const int i = i03*ne02 + i02;
              cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
              half  * c_X = d_X + i * x_ne;
              half  * c_Y = d_Y + i * y_ne;
              float * c_D = d_D + i * d_ne;
              // copy src0 to device
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
              // convert src1 to fp16 into this slice's own region of wdata
              // TODO: use multiple threads
              ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
              char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
              if (src1_cont_rows && src1_cont_cols) {
                  // whole slice is dense: convert it in one call
                  ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
              } else if (src1_cont_rows) {
                  // dense rows with a row pitch: convert row by row
                  for (int64_t i11 = 0; i11 < ne11; i11++) {
                      ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
                  }
              } else {
                  // fully strided: convert element by element
                  for (int64_t i11 = 0; i11 < ne11; i11++) {
                      for (int64_t i10 = 0; i10 < ne10; i10++) {
                          // very slow due to no inlining
                          tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
                      }
                  }
              }
              // copy src1 to device
              CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
              // compute
              CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
              CUBLAS_CHECK(
                  cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                          ne01, ne11, ne10,
                          &alpha, c_X, CUDA_R_16F, ne00,
                                  c_Y, CUDA_R_16F, ne10,
                          &beta,  c_D, CUDA_R_32F, ne01,
                          CUBLAS_COMPUTE_32F_FAST_16F,
                          CUBLAS_GEMM_DEFAULT));
              // copy dst to host
              float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
              CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
          }
      }
      CUDA_CHECK(cudaDeviceSynchronize());
      ggml_cuda_pool_free(d_X, x_size);
      ggml_cuda_pool_free(d_Y, y_size);
      ggml_cuda_pool_free(d_D, d_size);
  }
  // Quantized (or f16) src0 x f32 src1 matrix product; f32 results go to dst->data.
  //
  // For every 2D slice (i02, i03):
  //   - src0 comes from one of two places:
  //       * backend GGML_BACKEND_CPU: uploaded into a pooled staging buffer on
  //         the slice's secondary stream;
  //       * backend GGML_BACKEND_CUDA: used in place, with no staging buffer
  //         allocated.
  //   - If src1 is a single column (ne11 == 1), the fused dequantize-mul-mat-vec
  //     launcher for the type is used.
  //   - Otherwise src0 is dequantized to f32 on the secondary stream and
  //     multiplied with cublasSgemm on the primary stream.
  // An event recorded on the secondary stream orders its work before the
  // compute on the primary stream. The function blocks on cudaDeviceSynchronize
  // before returning its scratch buffers to the pool.
  static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
      const int64_t ne00 = src0->ne[0];
      const int64_t ne01 = src0->ne[1];
      const int64_t ne02 = src0->ne[2];
      const int64_t ne03 = src0->ne[3];
      const int64_t ne10 = src1->ne[0];
      const int64_t ne11 = src1->ne[1];
      // byte strides as size_t so large tensors cannot overflow an int
      const size_t nb2 = dst->nb[2];
      const size_t nb3 = dst->nb[3];
      const ggml_type type = src0->type;
      const bool mul_mat_vec = ne11 == 1;
      const float alpha = 1.0f;
      const float beta = 0.0f;
      const int x_ne = ne01 * ne00;
      const int y_ne = ne11 * ne10;
      const int d_ne = ne11 * ne01;
      const int n_mm = ne03 * ne02;
      // bytes of one encoded 2D slice of src0
      const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);

      // validate everything before allocating, so a failed check leaks nothing
      const bool src0_on_device = src0->backend == GGML_BACKEND_CUDA;
      GGML_ASSERT(src0_on_device || src0->backend == GGML_BACKEND_CPU);
      const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
      const dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
      GGML_ASSERT(to_fp32_cuda != nullptr);
      GGML_ASSERT(!mul_mat_vec || dmmv != nullptr);

      size_t x_size = 0, y_size = 0, d_size = 0, q_size = 0;
      float * d_X = nullptr; // dequantized src0: only needed by the cuBLAS path
      if (!mul_mat_vec) {
          d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
      }
      float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
      float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
      char * d_Q = nullptr;  // staging copy of src0: only needed when it lives on the host
      if (!src0_on_device) {
          d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
      }

      for (int64_t i03 = 0; i03 < ne03; i03++) {
          for (int64_t i02 = 0; i02 < ne02; i02++) {
              const int i = i03*ne02 + i02;
              cudaStream_t cudaStream  = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
              cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
              cudaEvent_t  cudaEvent   = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
              float * c_Y = d_Y + i * y_ne;
              float * c_D = d_D + i * d_ne;

              char * c_Q = nullptr;
              if (src0_on_device) {
                  // assumes device-resident src0 stores its slices back to back, q_sz bytes apart
                  c_Q = ((char *) src0->data) + i * q_sz;
              } else {
                  c_Q = d_Q + i * q_sz;
                  CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
              }

              if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
                  CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
                  // copy src1 to device
                  CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
                  // wait for src0 data
                  CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
                  dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
                  CUDA_CHECK(cudaGetLastError());
              } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
                  float * c_X = d_X + i * x_ne;
                  // convert src0 to fp32 on device
                  to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
                  CUDA_CHECK(cudaGetLastError());
                  CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
                  // copy src1 to device
                  CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
                  // wait for conversion
                  CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
                  CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
                  CUBLAS_CHECK(
                      cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
                              ne01, ne11, ne10,
                              &alpha, c_X, ne00,
                                      c_Y, ne10,
                              &beta,  c_D, ne01));
              }

              // copy dst to host
              float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
              CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
          }
      }

      CUDA_CHECK(cudaDeviceSynchronize());
      if (!mul_mat_vec) {
          ggml_cuda_pool_free(d_X, x_size);
      }
      ggml_cuda_pool_free(d_Y, y_size);
      ggml_cuda_pool_free(d_D, d_size);
      if (!src0_on_device) {
          ggml_cuda_pool_free(d_Q, q_size);
      }
  }
  // Public entry point for element-wise multiplication; src0, src1 and dst must all be f32.
  void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
      ggml_cuda_mul_f32(src0, src1, dst);
  }

  // Decides whether a mul_mat should be routed to CUDA. It requires:
  //   - src0 is f32, f16 or quantized;
  //   - src1 and dst are f32;
  //   - either src0 already lives on the GPU, or dst and src1 are at least 32
  //     in the relevant dimensions.
  bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      const bool src0_type_ok = src0->type == GGML_TYPE_F32 ||
                                src0->type == GGML_TYPE_F16 ||
                                ggml_is_quantized(src0->type);
      if (!src0_type_ok || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
          return false;
      }
      if (src0->backend == GGML_BACKEND_CUDA) {
          return true; // weights are already on the device
      }
      // TODO: find the optimal values for these
      const int64_t ne10 = src1->ne[0];
      const int64_t ne0  = dst->ne[0];
      const int64_t ne1  = dst->ne[1];
      return ne0 >= 32 && ne1 >= 32 && ne10 >= 32;
  }

  // Chooses between the two f16 strategies by bytes sent to the device. Both
  // upload src0 unchanged, so they differ only in how src1 travels:
  //   - mul_mat_q:   src1 as f32, ggml_nbytes(src1) bytes;
  //   - mul_mat_f16: src1 converted to fp16 on the CPU, sizeof(half) per element.
  // Returns true when the fp16 route transfers fewer bytes.
  bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
      const size_t src0_sz = ggml_nbytes(src0);
      const size_t q_path_bytes   = src0_sz + ggml_nbytes(src1);
      const size_t f16_path_bytes = src0_sz + sizeof(half) * ggml_nelements(src1);
      // TODO: this is not always the best choice due to the overhead of converting to fp16
      return f16_path_bytes < q_path_bytes;
  }
  // Public mul_mat entry point; dispatches on src0's type:
  //   - f32       -> cuBLAS SGEMM path;
  //   - f16       -> fp16 GemmEx path when ggml_cuda_mul_mat_use_f16 says so,
  //                  else the dequantize path;
  //   - quantized -> dequantize path.
  // Any other type fails the final assertion. `wdata` is only forwarded to the f16 path.
  void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
      GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
      const ggml_type src0_type = src0->type;

      if (src0_type == GGML_TYPE_F32) {
          ggml_cuda_mul_mat_f32(src0, src1, dst);
          return;
      }
      if (src0_type == GGML_TYPE_F16 && ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
          ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
          return;
      }
      if (src0_type == GGML_TYPE_F16 || ggml_is_quantized(src0_type)) {
          ggml_cuda_mul_mat_q_f32(src0, src1, dst);
          return;
      }
      GGML_ASSERT(false);
  }

  // Host work-buffer size needed by ggml_cuda_mul_mat: room for src1 as fp16
  // whenever ggml_cuda_mul_mat_use_f16 picks the fp16 route, otherwise 0.
  // NOTE(review): use_f16 does not inspect src0->type, so space is also reserved
  // for quantized src0, whose path never touches wdata.
  size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
      return ggml_cuda_mul_mat_use_f16(src0, src1, dst)
          ? ggml_nelements(src1) * sizeof(ggml_fp16_t)
          : 0;
  }
  // Moves `tensor`'s data into device memory taken from the CUDA buffer pool and
  // marks the tensor GGML_BACKEND_CUDA.
  //
  // Each 2D slice (i2, i3) is stored densely at byte offset (i3*ne2 + i2) * slice_sz,
  // where slice_sz is the encoded size of ne0 x ne1 elements. This is the per-slice
  // byte layout ggml_cuda_mul_mat_q_f32 reads from device-resident src0.
  //
  // The copies are asynchronous reads of host memory. They are waited on before
  // tensor->data is replaced, so the caller may release the host buffer afterwards.
  void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
      const int64_t ne0 = tensor->ne[0];
      const int64_t ne1 = tensor->ne[1];
      const int64_t ne2 = tensor->ne[2];
      const int64_t ne3 = tensor->ne[3];
      const ggml_type type = tensor->type;
      const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
      // bytes of one densely packed 2D slice; element counts (ne0*ne1) are not
      // byte counts for quantized or f16 types
      const size_t slice_sz = ggml_type_size(type) * ne0 * ne1 / ggml_blck_size(type);
      size_t q_size;
      char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
      cudaStream_t cudaStream2 = g_cudaStreams2[0];
      // copy tensor to device
      for (int64_t i3 = 0; i3 < ne3; i3++) {
          for (int64_t i2 = 0; i2 < ne2; i2++) {
              const int64_t i = i3*ne2 + i2;
              CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*slice_sz, tensor, i3, i2, cudaStream2));
          }
      }
      // wait until the host data has been read before dropping the reference to it
      CUDA_CHECK(cudaStreamSynchronize(cudaStream2));
      tensor->data = dst;
      tensor->backend = GGML_BACKEND_CUDA;
  }
  // Reads ggml_nbytes(tensor) bytes from file `fname`, starting at byte `offset`.
  // The bytes are copied into a newly cudaMalloc'ed device buffer, and tensor->data
  // is pointed at it. I/O failures print a message and exit(1); a failed seek fails
  // GGML_ASSERT; CUDA calls are checked with CUDA_CHECK.
  void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
      FILE * fp = fopen(fname, "rb");
      if (fp == nullptr) {
          fprintf(stderr, "%s: failed to open %s\n", __func__, fname);
          exit(1);
      }
      const size_t size = ggml_nbytes(tensor);

      void * buf_host = malloc(size);
      if (buf_host == nullptr) {
          fprintf(stderr, "%s: failed to allocate %zu bytes of host memory\n", __func__, size);
          fclose(fp);
          exit(1);
      }

      // 64-bit seek on Windows, where long is 32 bits
  #ifdef _WIN32
      int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
  #else
      int ret = fseek(fp, (long) offset, SEEK_SET);
  #endif
      GGML_ASSERT(ret == 0);

      const size_t ret2 = fread(buf_host, size, 1, fp);
      fclose(fp);
      if (ret2 != 1) {
          fprintf(stderr, "unexpectedly reached end of file\n");
          free(buf_host);
          exit(1);
      }

      // allocate the device buffer only once the data has been read successfully
      void * buf;
      CUDA_CHECK(cudaMalloc(&buf, size));
      CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
      CUDA_CHECK(cudaDeviceSynchronize());
      free(buf_host);

      tensor->data = buf;
  }