dmmv.cu 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. #include "dmmv.cuh"
  2. #include "dequantize.cuh"
  3. #include "convert.cuh"
  // K_QUANTS_PER_ITERATION is a compile-time knob used by the k-quant kernels in this file to map
  // threads onto quant blocks. It may be predefined by the build; only 1 and 2 are accepted.
  // When it is not predefined it defaults to 2.
  #ifdef K_QUANTS_PER_ITERATION
  static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
  #else
  #define K_QUANTS_PER_ITERATION 2
  #endif
  // Fused dequantize + matrix-vector product for a Q2_K matrix: dst[row] = dot(X[row, :], yy).
  //
  //   vx    - quantized matrix, read as block_q2_K; each row holds ncols / QK_K blocks
  //   yy    - float vector every row is dotted with
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //   nrows - number of rows to compute; rows >= nrows are skipped
  //
  // The row index is blockIdx.x*blockDim.y + threadIdx.y. Threads that share a threadIdx.y accumulate
  // partial sums for that row, warp_reduce_sum combines them and threadIdx.x == 0 stores the result.
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The launch grid is rounded up to a multiple of blockDim.y, so row == nrows can occur for an
      // odd row count. It must be rejected as well, otherwise dst[nrows] is written out of bounds.
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q2_K * x = (const block_q2_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 16/K_QUANTS_PER_ITERATION;
      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
      const int q_offset = 32*im + l0;
      const int s_offset = 8*im;
      const int y_offset = 128*im + l0;

      // aux[0..1] hold the low nibbles of 8 scale bytes, aux[2..3] the high nibbles;
      // d and m are byte views onto those two halves.
      uint32_t aux[4];
      const uint8_t * d = (const uint8_t *)aux;
      const uint8_t * m = (const uint8_t *)(aux + 2);

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
          aux[0] = a[0] & 0x0f0f0f0f;
          aux[1] = a[1] & 0x0f0f0f0f;
          aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
          aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

          float sum1 = 0, sum2 = 0;
          for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
              sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                    + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                    + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
                    + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
                    + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
                    + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
                    + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
                    +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
              sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
                    + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];

          }
          tmp += dall * sum1 - dmin * sum2;

      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Fused dequantize + matrix-vector product for a Q3_K matrix: dst[row] = dot(X[row, :], yy).
  //
  //   vx    - quantized matrix, read as block_q3_K; each row holds ncols / QK_K blocks
  //   yy    - float vector every row is dotted with
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //   nrows - number of rows to compute; rows >= nrows are skipped
  //
  // The row index is blockIdx.x*blockDim.y + threadIdx.y. Partial sums of the threads sharing a
  // threadIdx.y are combined with warp_reduce_sum and stored by threadIdx.x == 0.
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // row == nrows is already past the last valid row (the grid is rounded up by the launcher),
      // so the guard has to be >=, not >.
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q3_K * x = (const block_q3_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

      const uint16_t kmask1 = 0x0303;
      const uint16_t kmask2 = 0x0f0f;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
      const int step = 16/K_QUANTS_PER_ITERATION;
      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0....15 or 0...7

      const uint8_t m = 1 << (4*im);                       // first hmask bit tested by this thread

      const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
      const int q_offset =  32*im + l0;
      const int y_offset = 128*im + l0;

      // utmp receives the reassembled scales: 4 low bits from a[0..3], 2 high bits from a[4..5];
      // s is a signed byte view onto it.
      uint16_t utmp[4];
      const int8_t * s = (const int8_t *)utmp;

      const uint16_t s_shift = 4*im;

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;
          const uint8_t * h = x[i].hmask + l0;

          const uint16_t * a = (const uint16_t *)x[i].scales;
          utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
          utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
          utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
          utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);

          const float d = x[i].d;

          float sum = 0;
          for (int l = 0; l < n; ++l) {
              // a cleared hmask bit subtracts 4 from the 2-bit quant
              sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
                   + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
                   + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
                   + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
              sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
                   + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
                   + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
                  + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
          }
          tmp += d * sum;

      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Fused dequantize + matrix-vector product for a Q4_K matrix: dst[row] = dot(X[row, :], yy).
  //
  //   vx    - quantized matrix, read as block_q4_K; each row holds ncols / QK_K blocks
  //   yy    - float vector every row is dotted with
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //   nrows - number of rows to compute; rows >= nrows are skipped
  //
  // The row index is blockIdx.x*blockDim.y + threadIdx.y. Partial sums of the threads sharing a
  // threadIdx.y are combined with warp_reduce_sum and stored by threadIdx.x == 0.
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // row == nrows is already past the last valid row (the grid is rounded up by the launcher),
      // so the guard has to be >=, not >.
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q4_K * x = (const block_q4_K *)vx + ib0;

      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4

      const int il  = tid/step;                            // 0...3
      const int ir  = tid - step*il;                       // 0...7 or 0...3
      const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4

      const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int in = il%2;

      const int l0 = n*(2*ir + in);
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      // aux holds the unpacked scale/min bytes for this thread's sub-blocks; sc is a byte view onto it.
      uint16_t aux[4];
      const uint8_t * sc = (const uint8_t *)aux;

  #if K_QUANTS_PER_ITERATION == 2
      uint32_t q32[4];
      const uint8_t * q4 = (const uint8_t *)q32;
  #else
      uint16_t q16[4];
      const uint8_t * q4 = (const uint8_t *)q16;
  #endif

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y1 = yy + i*QK_K + y_offset;
          const float   * y2 = y1 + 128;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint16_t * a = (const uint16_t *)x[i].scales;
          aux[0] = a[im+0] & kmask1;
          aux[1] = a[im+2] & kmask1;
          aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
          aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

  #if K_QUANTS_PER_ITERATION == 2
          const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
          const uint32_t * q2 = q1 + 16;

          // high nibbles are kept in place (not shifted down); the 1/16 factor below compensates
          q32[0] = q1[0] & 0x0f0f0f0f;
          q32[1] = q1[0] & 0xf0f0f0f0;
          q32[2] = q2[0] & 0x0f0f0f0f;
          q32[3] = q2[0] & 0xf0f0f0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 4; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
              s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #else
          const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
          const uint16_t * q2 = q1 + 32;

          // high nibbles are kept in place (not shifted down); the 1/16 factor below compensates
          q16[0] = q1[0] & 0x0f0f;
          q16[1] = q1[0] & 0xf0f0;
          q16[2] = q2[0] & 0x0f0f;
          q16[3] = q2[0] & 0xf0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 2; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
              s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #endif

      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Gate on threadIdx.x like the q2_k/q3_k kernels: tid == 0 would also match threadIdx.x == 1
      // when K_QUANTS_PER_ITERATION == 2 and issue a second, redundant store to the same element.
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Fused dequantize + matrix-vector product for a Q5_K matrix: dst[row] = dot(X[row, :], yy).
  //
  //   vx    - quantized matrix, read as block_q5_K; each row holds ncols / QK_K blocks
  //   yy    - float vector every row is dotted with
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //
  // One thread block per row (row = blockIdx.x); there is no row bound check in this kernel, so the
  // grid must not contain more blocks than rows. Even/odd threadIdx.x alternate over the row's blocks.
  static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {

      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

      const int row_idx       = blockIdx.x;
      const int blocks_in_row = ncols / QK_K;
      const int first_block   = row_idx*blocks_in_row;

      const block_q5_K * blocks = (const block_q5_K *)vx + first_block;

      // thread -> data mapping
      const int pair_id  = threadIdx.x/2;        // 0...15
      const int blk_lane = threadIdx.x%2;        // which of the two interleaved block streams
      const int quarter  = pair_id/4;            // 0...3
      const int sub      = pair_id - 4*quarter;  // 0...3
      const int n        = 2;                    // elements per group handled in the inner loop

      const int half_sel = quarter/2;            // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int parity   = quarter%2;

      const int elem0 = n*(2*sub + parity);
      const int q_off = 32*half_sel + elem0;
      const int y_off = 64*half_sel + elem0;

      // bit masks selecting this thread's high bits inside the qh bytes
      const uint8_t hm1 = 1 << (2*half_sel);
      const uint8_t hm2 = hm1 << 4;

      uint16_t scale_words[4];
      const uint8_t * sc = (const uint8_t *)scale_words;

      uint16_t nibble_words[8];
      const uint8_t * q4 = (const uint8_t *)nibble_words;

      float acc = 0; // partial sum for thread in warp

      for (int ib = blk_lane; ib < blocks_in_row; ib += 2) {

          const uint8_t * ql1 = blocks[ib].qs + q_off;
          const uint8_t * qh  = blocks[ib].qh + elem0;
          const float   * y1  = yy + ib*QK_K + y_off;
          const float   * y2  = y1 + 128;

          const float dall = __low2half(blocks[ib].dm);
          const float dmin = __high2half(blocks[ib].dm);

          // unpack the scale / min bytes used by this thread
          const uint16_t * a = (const uint16_t *)blocks[ib].scales;
          scale_words[0] = a[half_sel+0] & kmask1;
          scale_words[1] = a[half_sel+2] & kmask1;
          scale_words[2] = ((a[half_sel+4] >> 0) & kmask2) | ((a[half_sel+0] & kmask3) >> 2);
          scale_words[3] = ((a[half_sel+4] >> 4) & kmask2) | ((a[half_sel+2] & kmask3) >> 2);

          // split the packed bytes into low / high nibbles, two bytes at a time
          const uint16_t * q1 = (const uint16_t *)ql1;
          const uint16_t * q2 = q1 + 32;
          nibble_words[0] = q1[0] & 0x0f0f;
          nibble_words[1] = q1[8] & 0x0f0f;
          nibble_words[2] = (q1[0] >> 4) & 0x0f0f;
          nibble_words[3] = (q1[8] >> 4) & 0x0f0f;
          nibble_words[4] = q2[0] & 0x0f0f;
          nibble_words[5] = q2[8] & 0x0f0f;
          nibble_words[6] = (q2[0] >> 4) & 0x0f0f;
          nibble_words[7] = (q2[8] >> 4) & 0x0f0f;

          float sum_x = 0.f;
          float sum_y = 0.f;
          float sum_z = 0.f;
          float sum_w = 0.f;
          float smin  = 0;

          for (int l = 0; l < n; ++l) {
              // a set qh bit contributes the 5th bit (+16) of the quant
              sum_x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
                     + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
              sum_y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
                     + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
              sum_z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
                     + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
              sum_w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
                     + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
              smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                    + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
          }
          acc += dall * (sum_x * sc[0] + sum_y * sc[1] + sum_z * sc[4] + sum_w * sc[5]) - dmin * smin;
      }

      // combine the per-thread partial sums; threadIdx.x == 0 stores the row result
      acc = warp_reduce_sum(acc);

      if (threadIdx.x == 0) {
          dst[row_idx] = acc;
      }
  }
  // Fused dequantize + matrix-vector product for a Q6_K matrix: dst[row] = dot(X[row, :], yy).
  //
  //   vx    - quantized matrix, read as block_q6_K; each row holds ncols / QK_K blocks
  //   yy    - float vector every row is dotted with
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //   nrows - number of rows to compute; rows >= nrows are skipped
  //
  // The row index is blockIdx.x*blockDim.y + threadIdx.y. Partial sums of the threads sharing a
  // threadIdx.y are combined with warp_reduce_sum and stored by threadIdx.x == 0.
  static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // row == nrows is already past the last valid row (the grid is rounded up by the launcher),
      // so the guard has to be >=, not >.
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q6_K * x = (const block_q6_K *)vx + ib0;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1

      const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8

      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

  #if K_QUANTS_PER_ITERATION == 1
      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
      const int is = 0;
  #else
      const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
      const int is = in / 4;
  #endif
      const int ql_offset = 64*im + l0;
      const int qh_offset = 32*im + l0;
      const int s_offset  =  8*im + is;
      const int y_offset = 128*im + l0;

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * ql = x[i].ql + ql_offset;
          const uint8_t * qh = x[i].qh + qh_offset;
          const int8_t  * s  = x[i].scales + s_offset;

          const float d = x[i].d;

          // each quant is rebuilt from a 4-bit part in ql and a 2-bit part in qh, then offset by -32
  #if K_QUANTS_PER_ITERATION == 1
          float sum = y[  0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
                    + y[ 16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
                    + y[ 32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
                    + y[ 48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
                    + y[ 64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
                    + y[ 80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
                    + y[ 96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
                    +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
          tmp += sum;
  #else
          float sum = 0;
          for (int l = 0; l < 4; ++l) {
              sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
                   + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
                   + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
                   + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
          }
          tmp += sum;
  #endif

      }

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Gate on threadIdx.x like the q2_k/q3_k kernels: tid == 0 would also match threadIdx.x == 1
      // when K_QUANTS_PER_ITERATION == 2 and issue a second, redundant store to the same element.
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // "Dequantize" callback for plain f16 data: loads the two consecutive half values at element
  // index ib + iqs and ib + iqs + 1 of vx into v.x and v.y.
  static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
      const half * pair = (const half *) vx + (ib + iqs);

      // automatic half -> float type cast if dfloat == float
      v.x = pair[0];
      v.y = pair[1];
  }
  // Compile-time lookup of the per-element dequantize callback for a ggml type.
  // Returns nullptr for every type that is not listed here.
  static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
      // The comparisons are mutually exclusive, so their order is irrelevant; kept as a single
      // return expression so the function stays usable as a constant expression.
      return (type == GGML_TYPE_F16)  ? convert_f16     :
             (type == GGML_TYPE_Q8_0) ? dequantize_q8_0 :
             (type == GGML_TYPE_Q5_1) ? dequantize_q5_1 :
             (type == GGML_TYPE_Q5_0) ? dequantize_q5_0 :
             (type == GGML_TYPE_Q4_1) ? dequantize_q4_1 :
             (type == GGML_TYPE_Q4_0) ? dequantize_q4_0 :
                                        nullptr;
  }
  // Generic fused dequantize + matrix-vector product: dst[row] = dot(dequant(X[row, :]), y).
  //
  //   type  - ggml type of X; selects qk, qr and the dequantize callback at compile time
  //   vx    - matrix data in that type
  //   y     - vector in dfloat precision
  //   dst   - output, one float per row
  //   ncols - number of columns of X
  //   nrows - number of rows; rows >= nrows are skipped
  //
  // The row index is blockIdx.x*blockDim.y + threadIdx.y. Every pass of the outer loop covers
  // 2*GGML_CUDA_DMMV_X columns and there is no tail check inside a pass.
  template <ggml_type type>
  static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
      constexpr int qk = ggml_cuda_type_traits<type>::qk; // quantized weights per x block
      constexpr int qr = ggml_cuda_type_traits<type>::qr; // number of quantized weights per data value in x block
      constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type);

      const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y;
      if (row >= nrows) return;

      const int lane          = threadIdx.x;
      const int cols_per_pass = 2*GGML_CUDA_DMMV_X;
      const int cols_per_lane = cols_per_pass / WARP_SIZE; // num quantized vals per thread and pass
      const int y_pair_stride = qr == 1 ? 1 : qk/2;         // distance in y between the two values of v

      // per-thread partial sum
  #ifdef GGML_CUDA_F16
      half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
  #else
      float tmp = 0.0f;
  #endif // GGML_CUDA_F16

      for (int pass_start = 0; pass_start < ncols; pass_start += cols_per_pass) {
          const int     col        = pass_start + cols_per_lane*lane;
          const int     col_in_blk = col%qk;
          const int64_t ib         = ((int64_t)row*ncols + col)/qk; // x block index
          const int     iqs        = col_in_blk/qr;                 // x quant index
          const int     iybs       = col - col_in_blk;              // y block start index

          // processing >2 values per pass is faster for fast GPUs
  #pragma unroll
          for (int j = 0; j < cols_per_lane; j += 2) {
              // two values are handled per step; for qr = 2 the quant index and the y index both
              // advance by 1 per step because one data value holds 2 weights and y_pair_stride = qk/2
              const int q_idx = iqs + j/qr;
              const int y_idx = iybs + q_idx;

              dfloat2 v;
              dequantize_kernel(vx, ib, q_idx, v);

  #ifdef GGML_CUDA_F16
              tmp += __hmul2(v, {
                  y[y_idx + 0],
                  y[y_idx + y_pair_stride]
              });
  #else
              tmp += v.x * y[y_idx + 0];
              tmp += v.y * y[y_idx + y_pair_stride];
  #endif // GGML_CUDA_F16
          }
      }

      // combine the partial sums; threadIdx.x == 0 stores the row result
      tmp = warp_reduce_sum(tmp);

      if (lane == 0) {
  #ifdef GGML_CUDA_F16
          dst[row] = tmp.x + tmp.y;
  #else
          dst[row] = tmp;
  #endif // GGML_CUDA_F16
      }
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_Q4_0> on `stream`, computing nrows outputs into dst
  // from the Q4_0 matrix vx (ncols columns) and the vector y.
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q4_0>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_Q4_1> on `stream`, computing nrows outputs into dst
  // from the Q4_1 matrix vx (ncols columns) and the vector y.
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // rows are mapped onto grid.x, GGML_CUDA_MMV_Y rows per block
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q4_1>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_Q5_0> on `stream`, computing nrows outputs into dst
  // from the Q5_0 matrix vx (ncols columns) and the vector y.
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // rows are mapped onto grid.x, GGML_CUDA_MMV_Y rows per block
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q5_0>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_Q5_1> on `stream`, computing nrows outputs into dst
  // from the Q5_1 matrix vx (ncols columns) and the vector y.
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // rows are mapped onto grid.x, GGML_CUDA_MMV_Y rows per block
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q5_1>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_Q8_0> on `stream`, computing nrows outputs into dst
  // from the Q8_0 matrix vx (ncols columns) and the vector y.
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // rows are mapped onto grid.x, GGML_CUDA_MMV_Y rows per block
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_Q8_0>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec_q2_k on `stream` for nrows rows of ncols columns.
  // ncols must be a whole number of QK_K-sized blocks.
  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      const int rows_per_block = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
      const int grid_x         = (nrows + rows_per_block - 1) / rows_per_block; // ceil-div over rows

      const dim3 grid(grid_x, 1, 1);
      const dim3 threads(32, rows_per_block, 1);
      dequantize_mul_mat_vec_q2_k<<<grid, threads, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec_q3_k on `stream` for nrows rows of ncols columns.
  // ncols must be a whole number of QK_K-sized blocks.
  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      const int rows_per_block = 2 / K_QUANTS_PER_ITERATION; // 2 or 1
      const int grid_x         = (nrows + rows_per_block - 1) / rows_per_block; // ceil-div over rows

      const dim3 grid(grid_x, 1, 1);
      const dim3 threads(32, rows_per_block, 1);
      dequantize_mul_mat_vec_q3_k<<<grid, threads, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec_q4_k on `stream` for nrows rows of ncols columns.
  // ncols must be a whole number of QK_K-sized blocks.
  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      const int rows_per_block = 2 / K_QUANTS_PER_ITERATION; // 2 or 1
      const int grid_x         = (nrows + rows_per_block - 1) / rows_per_block; // ceil-div over rows

      const dim3 grid(grid_x, 1, 1);
      const dim3 threads(32, rows_per_block, 1);
      dequantize_mul_mat_vec_q4_k<<<grid, threads, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec_q5_k on `stream` with one 32-thread block per row.
  // ncols must be a whole number of QK_K-sized blocks.
  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // the kernel takes its row from blockIdx.x and has no bound check, so the grid is exactly nrows
      const dim3 grid(nrows, 1, 1);
      const dim3 threads(32, 1, 1);
      dequantize_mul_mat_vec_q5_k<<<grid, threads, 0, stream>>>(vx, y, dst, ncols);
  }
  // Launches dequantize_mul_mat_vec_q6_k on `stream` for nrows rows of ncols columns.
  // ncols must be a whole number of QK_K-sized blocks.
  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      const int rows_per_block = 2 / K_QUANTS_PER_ITERATION; // 2 or 1
      const int grid_x         = (nrows + rows_per_block - 1) / rows_per_block; // ceil-div over rows

      const dim3 grid(grid_x, 1, 1);
      const dim3 threads(32, rows_per_block, 1);
      dequantize_mul_mat_vec_q6_k<<<grid, threads, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches dequantize_mul_mat_vec<GGML_TYPE_F16> on `stream`, computing nrows outputs into dst
  // from the f16 matrix vx (ncols columns) and the vector y.
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      // The kernel walks each row in passes of 2*GGML_CUDA_DMMV_X columns without a tail check, so
      // ncols must be a multiple of that full stride. Requiring only a multiple of GGML_CUDA_DMMV_X
      // lets an odd multiple through, and the last pass then reads past the end of the row.
      GGML_ASSERT(ncols % (2*GGML_CUDA_DMMV_X) == 0);
      const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      // rows are mapped onto grid.x, GGML_CUDA_MMV_Y rows per block
      const dim3 block_nums(block_num_y, 1, 1);
      const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
      dequantize_mul_mat_vec<GGML_TYPE_F16>
          <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Host entry point: picks the dequantize-mul-mat-vec launcher matching src0->type and runs it on
  // `stream` for rows [row_low, row_high) of src0.
  //
  //   src0_dd_i  - device data of src0 (the matrix)
  //   src1_ddf_i - device data of src1 as float (the vector); src1 must be GGML_TYPE_F32
  //   dst_dd_i   - device output, written by the launched kernel
  //
  // With GGML_CUDA_F16 defined, src1 is first converted to half for the non-k-quant types and the
  // F16 type; the k-quant launchers always receive the float data. Unsupported src0 types abort.
  void ggml_cuda_op_dequantize_mul_mat_vec(
      ggml_backend_cuda_context & ctx,
      const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
      const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
      const int64_t src1_padded_row_size, cudaStream_t stream) {

      // parameters that are part of the common op signature but not needed here
      GGML_UNUSED(ctx);
      GGML_UNUSED(src1);
      GGML_UNUSED(dst);
      GGML_UNUSED(src1_ddq_i);
      GGML_UNUSED(src1_ncols);
      GGML_UNUSED(src1_padded_row_size);

      const int64_t n_cols = src0->ne[0];
      const int64_t n_rows = row_high - row_low;

      GGML_ASSERT(src1->type == GGML_TYPE_F32);

      // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
  #ifdef GGML_CUDA_F16
      ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
      half * src1_dfloat = nullptr; // dfloat == half

      bool needs_half_src1 = false;
      switch (src0->type) {
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
          case GGML_TYPE_Q5_0:
          case GGML_TYPE_Q5_1:
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_F16:
              needs_half_src1 = true;
              break;
          default:
              break;
      }

      if (needs_half_src1) {
          src1_dfloat = src1_dfloat_a.alloc(n_cols);
          const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
          GGML_ASSERT(to_fp16_cuda != nullptr);
          to_fp16_cuda(src1_ddf_i, src1_dfloat, n_cols, stream);
      }
  #else
      const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
  #endif // GGML_CUDA_F16

      switch (src0->type) {
          // generic kernel: vector in dfloat precision
          case GGML_TYPE_Q4_0: {
              dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q4_1: {
              dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q5_0: {
              dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q5_1: {
              dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q8_0: {
              dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;

          // k-quant kernels: always the float vector
          case GGML_TYPE_Q2_K: {
              dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q3_K: {
              dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q4_K: {
              dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q5_K: {
              dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream);
          } break;
          case GGML_TYPE_Q6_K: {
              dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream);
          } break;

          case GGML_TYPE_F16: {
              convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream);
          } break;

          default: {
              GGML_ASSERT(false);
          } break;
      }
  }