dmmv.cu 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821
  1. #include "dmmv.cuh"
  2. #include "dequantize.cuh"
  3. #include "convert.cuh"
  4. // dmmv = dequantize_mul_mat_vec
  // Compile-time tuning knobs for the kernels in this file. Each one can be predefined by the build.
  // GGML_CUDA_DMMV_X: the generic kernel advances 2*GGML_CUDA_DMMV_X columns per outer iteration,
  // and the generic launchers assert that ncols is a multiple of it.
  5. #ifndef GGML_CUDA_DMMV_X
  6. #define GGML_CUDA_DMMV_X 32
  7. #endif
  // GGML_CUDA_MMV_Y: blockDim.y used by the generic launchers, i.e. rows handled per thread block.
  8. #ifndef GGML_CUDA_MMV_Y
  9. #define GGML_CUDA_MMV_Y 1
  10. #endif
  // K_QUANTS_PER_ITERATION: controls how the threads of the k-quant kernels are split between
  // positions inside a super-block and distinct super-blocks. Only 1 and 2 are accepted; the
  // default (2) needs no check, a user-supplied value is validated by the static_assert below.
  11. #ifndef K_QUANTS_PER_ITERATION
  12. #define K_QUANTS_PER_ITERATION 2
  13. #else
  14. static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
  15. #endif
  // Q2_K matrix-vector product: dst[row] = sum over columns of dequant(vx[row, :]) * yy[:].
  //
  //   vx    : quantized matrix, ncols/QK_K block_q2_K super-blocks per row
  //   yy    : float input vector (ncols entries are read)
  //   dst   : float output, one entry per row
  //   ncols : columns per row; the launcher in this file asserts ncols % QK_K == 0
  //   nrows : number of valid rows
  //
  // Launch layout used by the launcher in this file: blockDim = (32, ny, 1) and
  // gridDim.x = ceil(nrows / ny). Each threadIdx.y slice works on one row and its per-thread
  // partial sums are combined with warp_reduce_sum (defined elsewhere).
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so row == nrows is reachable.
      // It must be rejected as well (the previous `row > nrows` let it through and wrote dst[nrows]).
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q2_K * x = (const block_q2_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

  #if QK_K == 256
      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 16/K_QUANTS_PER_ITERATION;

      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
      const int q_offset = 32*im + l0;
      const int s_offset = 8*im;
      const int y_offset = 128*im + l0;

      // aux[0..1] hold the low nibbles of 8 scale bytes, aux[2..3] the high nibbles;
      // d and m are byte views onto those two halves.
      uint32_t aux[4];
      const uint8_t * d = (const uint8_t *)aux;
      const uint8_t * m = (const uint8_t *)(aux + 2);

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
          aux[0] = a[0] & 0x0f0f0f0f;
          aux[1] = a[1] & 0x0f0f0f0f;
          aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
          aux[3] = (a[1] >> 4) & 0x0f0f0f0f;

          float sum1 = 0, sum2 = 0;
          for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
              sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                    + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                    + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
                    + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
                    + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
                    + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
                    + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
                    +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
              sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
                    + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];

          }
          tmp += dall * sum1 - dmin * sum2;

      }
  #else
      const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
      const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
      const int offset = tid * K_QUANTS_PER_ITERATION;

      // uaux[0] holds the low nibbles of 4 scale bytes (d[0..3]), uaux[1] the high nibbles (d[4..7]).
      uint32_t uaux[2];
      const uint8_t * d = (const uint8_t *)uaux;

      for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {

          const float   * y = yy + i * QK_K + offset;
          const uint8_t * q = x[i].qs + offset;
          const uint32_t * s = (const uint32_t *)x[i].scales;

          uaux[0] = s[0] & 0x0f0f0f0f;
          uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;

          const float2 dall = __half22float2(x[i].dm);

          float sum1 = 0, sum2 = 0;
          for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
              const uint8_t ql = q[l];
              sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
                    + y[l+16] * d[1] * ((ql >> 2) & 3)
                    + y[l+32] * d[2] * ((ql >> 4) & 3)
                    + y[l+48] * d[3] * ((ql >> 6) & 3);
              sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
          }
          tmp += dall.x * sum1 - dall.y * sum2;
      }
  #endif

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q3_K matrix-vector product: dst[row] = sum over columns of dequant(vx[row, :]) * yy[:].
  //
  //   vx    : quantized matrix, ncols/QK_K block_q3_K super-blocks per row
  //   yy    : float input vector (ncols entries are read)
  //   dst   : float output, one entry per row
  //   ncols : columns per row; the launcher in this file asserts ncols % QK_K == 0
  //   nrows : number of valid rows
  //
  // Launch layout used by the launcher in this file: blockDim = (32, ny, 1) and
  // gridDim.x = ceil(nrows / ny). Each threadIdx.y slice works on one row and its per-thread
  // partial sums are combined with warp_reduce_sum (defined elsewhere).
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so row == nrows is reachable.
      // It must be rejected as well (the previous `row > nrows` let it through and wrote dst[nrows]).
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q3_K * x = (const block_q3_K *)vx + ib0;

      float tmp = 0; // partial sum for thread in warp

  #if QK_K == 256

      const uint16_t kmask1 = 0x0303;
      const uint16_t kmask2 = 0x0f0f;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
      const int step = 16/K_QUANTS_PER_ITERATION;
      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0....15 or 0...7

      const uint8_t m = 1 << (4*im);                       // first hmask bit used by this half

      const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
      const int q_offset =  32*im + l0;
      const int y_offset = 128*im + l0;

      // utmp receives 8 reassembled scale bytes per super-block; s is its signed byte view.
      uint16_t utmp[4];
      const int8_t * s = (const int8_t *)utmp;

      const uint16_t s_shift = 4*im;

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * q = x[i].qs + q_offset;
          const uint8_t * h = x[i].hmask + l0;

          // low 4 bits come from a[0..3], the 2 extra bits from a[4..5]
          const uint16_t * a = (const uint16_t *)x[i].scales;
          utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
          utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
          utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
          utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);

          const float d = x[i].d;

          float sum = 0;
          for (int l = 0; l < n; ++l) {
              sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
                   + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
                   + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
                   + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
              sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
                   + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
                   + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
                  + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
          }
          tmp += d * sum;

      }
  #else

      const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
      const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
      const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
      const int in = offset/8;                                 // 0 or 1
      const int im = offset%8;                                 // 0...7

      for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {

          const float   * y = yy + i * QK_K + offset;
          const uint8_t * q = x[i].qs + offset;
          const uint8_t * s = x[i].scales;

          const float dall = (float)x[i].d;

          float sum = 0;
          for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
              const uint8_t hl = x[i].hmask[im+l] >> in;
              const uint8_t ql = q[l];
              sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
                   + y[l+16] * dall * ((s[0] >>  4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
                   + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
                   + y[l+48] * dall * ((s[1] >>  4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
          }
          tmp += sum;
      }
  #endif

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q4_K matrix-vector product: dst[row] = sum over columns of dequant(vx[row, :]) * yy[:].
  //
  //   vx    : quantized matrix, ncols/QK_K block_q4_K super-blocks per row
  //   yy    : float input vector (ncols entries are read)
  //   dst   : float output, one entry per row
  //   ncols : columns per row; the launcher in this file asserts ncols % QK_K == 0
  //   nrows : number of valid rows
  //
  // Launch layout used by the launcher in this file: blockDim = (32, ny, 1) and
  // gridDim.x = ceil(nrows / ny). Each threadIdx.y slice works on one row and its per-thread
  // partial sums are combined with warp_reduce_sum (defined elsewhere).
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so row == nrows is reachable.
      // It must be rejected as well (the previous `row > nrows` let it through and wrote dst[nrows]).
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q4_K * x = (const block_q4_K *)vx + ib0;

  #if QK_K == 256
      const uint16_t kmask1 = 0x3f3f;
      const uint16_t kmask2 = 0x0f0f;
      const uint16_t kmask3 = 0xc0c0;

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1

      const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4

      const int il  = tid/step;                            // 0...3
      const int ir  = tid - step*il;                       // 0...7 or 0...3
      const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4

      const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int in = il%2;

      const int l0 = n*(2*ir + in);
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      // aux receives the unpacked scale/min bytes; sc is its byte view
      uint16_t aux[4];
      const uint8_t * sc = (const uint8_t *)aux;

  #if K_QUANTS_PER_ITERATION == 2
      uint32_t q32[4];
      const uint8_t * q4 = (const uint8_t *)q32;
  #else
      uint16_t q16[4];
      const uint8_t * q4 = (const uint8_t *)q16;
  #endif

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y1 = yy + i*QK_K + y_offset;
          const float   * y2 = y1 + 128;

          const float dall = __low2half(x[i].dm);
          const float dmin = __high2half(x[i].dm);

          const uint16_t * a = (const uint16_t *)x[i].scales;
          aux[0] = a[im+0] & kmask1;
          aux[1] = a[im+2] & kmask1;
          aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
          aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);

  #if K_QUANTS_PER_ITERATION == 2
          const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
          const uint32_t * q2 = q1 + 16;

          // High nibbles are masked but not shifted down; the 1.f/16.f factor below compensates.
          q32[0] = q1[0] & 0x0f0f0f0f;
          q32[1] = q1[0] & 0xf0f0f0f0;
          q32[2] = q2[0] & 0x0f0f0f0f;
          q32[3] = q2[0] & 0xf0f0f0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 4; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
              s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #else
          const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
          const uint16_t * q2 = q1 + 32;

          // High nibbles are masked but not shifted down; the 1.f/16.f factor below compensates.
          q16[0] = q1[0] & 0x0f0f;
          q16[1] = q1[0] & 0xf0f0;
          q16[2] = q2[0] & 0x0f0f;
          q16[3] = q2[0] & 0xf0f0;

          float4 s = {0.f, 0.f, 0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < 2; ++l) {
              s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
              s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
              smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
          }
          tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
  #endif

      }
  #else
      const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
      const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);

      const int step = tid * K_QUANTS_PER_ITERATION;

      uint16_t aux16[2];
      const uint8_t * s = (const uint8_t *)aux16;

      float tmp = 0;

      for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
          const uint8_t * q = x[i].qs + step;
          const float   * y = yy + i*QK_K + step;
          const uint16_t * a = (const uint16_t *)x[i].scales;
          aux16[0] = a[0] & 0x0f0f;
          aux16[1] = (a[0] >> 4) & 0x0f0f;
          const float d = (float)x[i].dm[0];
          const float m = (float)x[i].dm[1];
          float sum = 0.f;
          for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
              sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
                   + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
                   + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
                   + y[j+48] * (d * s[1] * (q[j+16] >>  4) - m * s[3]);
          }
          tmp += sum;
      }

  #endif

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Gate on the lane index, as the q2/q3/q5 kernels do. `tid == 0` also matched lane 1 when
      // K_QUANTS_PER_ITERATION == 2, so two threads stored to dst[row].
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // Q5_K matrix-vector product: dst[row] = sum over columns of dequant(vx[row, :]) * yy[:].
  //
  //   vx    : quantized matrix, ncols/QK_K block_q5_K super-blocks per row
  //   yy    : float input vector
  //   dst   : float output, one entry per row
  //   ncols : columns per row
  //
  // The row index is blockIdx.x and there is no row-count guard, so the grid must have exactly
  // one block per row (the launcher in this file uses <<<nrows, (32,1,1)>>>).
  static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {

      const int out_row        = blockIdx.x;
      const int blocks_per_row = ncols / QK_K;
      const int first_block    = out_row*blocks_per_row;

      const block_q5_K * blk = (const block_q5_K *)vx + first_block;

      float acc = 0; // this thread's share of the dot product

  #if QK_K == 256
      const uint16_t mask_lo6 = 0x3f3f;
      const uint16_t mask_lo4 = 0x0f0f;
      const uint16_t mask_hi2 = 0xc0c0;

      const int pair_id = threadIdx.x/2;  // 0...15
      const int parity  = threadIdx.x%2;  // first super-block visited by this thread

      const int quarter = pair_id/4;            // 0...3
      const int slot    = pair_id - 4*quarter;  // 0...3
      const int n       = 2;

      const int im = quarter/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
      const int in = quarter%2;

      const int l0       = n*(2*slot + in);
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      // qh bit selecting the 5th quant bit for the y1 group, and the same bit 4 positions up for y2
      const uint8_t hbit_lo = 1 << (2*im);
      const uint8_t hbit_hi = hbit_lo << 4;

      uint16_t scale_words[4];
      const uint8_t * sc = (const uint8_t *)scale_words;

      uint16_t nibble_words[8];
      const uint8_t * q4 = (const uint8_t *)nibble_words;

      for (int i = parity; i < blocks_per_row; i += 2) {

          const uint8_t * ql1 = blk[i].qs + q_offset;
          const uint8_t * qh  = blk[i].qh + l0;
          const float   * y1  = yy + i*QK_K + y_offset;
          const float   * y2  = y1 + 128;

          const float dall = __low2half(blk[i].dm);
          const float dmin = __high2half(blk[i].dm);

          const uint16_t * a = (const uint16_t *)blk[i].scales;
          scale_words[0] = a[im+0] & mask_lo6;
          scale_words[1] = a[im+2] & mask_lo6;
          scale_words[2] = ((a[im+4] >> 0) & mask_lo4) | ((a[im+0] & mask_hi2) >> 2);
          scale_words[3] = ((a[im+4] >> 4) & mask_lo4) | ((a[im+2] & mask_hi2) >> 2);

          float4 sum  = {0.f, 0.f, 0.f, 0.f};
          float  smin = 0;

          const uint16_t * q1 = (const uint16_t *)ql1;
          const uint16_t * q2 = q1 + 32;
          nibble_words[0] = q1[0] & 0x0f0f;
          nibble_words[1] = q1[8] & 0x0f0f;
          nibble_words[2] = (q1[0] >> 4) & 0x0f0f;
          nibble_words[3] = (q1[8] >> 4) & 0x0f0f;
          nibble_words[4] = q2[0] & 0x0f0f;
          nibble_words[5] = q2[8] & 0x0f0f;
          nibble_words[6] = (q2[0] >> 4) & 0x0f0f;
          nibble_words[7] = (q2[8] >> 4) & 0x0f0f;

          for (int l = 0; l < n; ++l) {
              sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hbit_lo << 0) ? 16 : 0))
                     + y1[l+16] * (q4[l +2] + (qh[l+16] & (hbit_lo << 0) ? 16 : 0));
              sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hbit_lo << 1) ? 16 : 0))
                     + y1[l+48] * (q4[l +6] + (qh[l+16] & (hbit_lo << 1) ? 16 : 0));
              sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hbit_hi << 0) ? 16 : 0))
                     + y2[l+16] * (q4[l+10] + (qh[l+16] & (hbit_hi << 0) ? 16 : 0));
              sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hbit_hi << 1) ? 16 : 0))
                     + y2[l+48] * (q4[l+14] + (qh[l+16] & (hbit_hi << 1) ? 16 : 0));
              smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
                    + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
          }
          acc += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
      }

  #else
      const int group  = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...15
      const int start  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
      const int offset = group * K_QUANTS_PER_ITERATION;
      const int im = offset/8;
      const int in = offset%8;

      for (int i = start; i < blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
          const uint8_t * q = blk[i].qs + offset;
          const int8_t  * s = blk[i].scales;
          const float   * y = yy + i*QK_K + offset;
          const float     d = blk[i].d;
          float sum = 0.f;
          for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
              const uint8_t h = blk[i].qh[in+j] >> im;
              sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
                   + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
                   + y[j+32] * d * s[2] * ((q[j+ 0] >>  4) - ((h >> 4) & 1 ? 0 : 16))
                   + y[j+48] * d * s[3] * ((q[j+16] >>  4) - ((h >> 6) & 1 ? 0 : 16));
          }
          acc += sum;
      }
  #endif

      // combine the per-thread shares and let lane 0 store the row result
      acc = warp_reduce_sum(acc);

      if (threadIdx.x == 0) {
          dst[out_row] = acc;
      }
  }
  // Q6_K matrix-vector product: dst[row] = sum over columns of dequant(vx[row, :]) * yy[:].
  //
  //   vx    : quantized matrix, ncols/QK_K block_q6_K super-blocks per row
  //   yy    : float input vector (ncols entries are read)
  //   dst   : float output, one entry per row
  //   ncols : columns per row; the launcher in this file asserts ncols % QK_K == 0
  //   nrows : number of valid rows
  //
  // Launch layout used by the launcher in this file: blockDim = (32, ny, 1) and
  // gridDim.x = ceil(nrows / ny). Each threadIdx.y slice works on one row and its per-thread
  // partial sums are combined with warp_reduce_sum (defined elsewhere).
  static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {

      static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");

      const int row = blockIdx.x*blockDim.y + threadIdx.y;
      // The grid is rounded up to a multiple of blockDim.y, so row == nrows is reachable.
      // It must be rejected as well (the previous `row > nrows` let it through and wrote dst[nrows]).
      if (row >= nrows) return;

      const int num_blocks_per_row = ncols / QK_K;
      const int ib0 = row*num_blocks_per_row;

      const block_q6_K * x = (const block_q6_K *)vx + ib0;

  #if QK_K == 256

      const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
      const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1

      const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8

      const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
      const int in = tid - step*im;                        // 0...15 or 0...7

  #if K_QUANTS_PER_ITERATION == 1
      const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
      const int is = 0;
  #else
      const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
      const int is = in / 4;
  #endif
      const int ql_offset = 64*im + l0;
      const int qh_offset = 32*im + l0;
      const int s_offset  =  8*im + is;
      const int y_offset = 128*im + l0;

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + y_offset;
          const uint8_t * ql = x[i].ql + ql_offset;
          const uint8_t * qh = x[i].qh + qh_offset;
          const int8_t  * s  = x[i].scales + s_offset;

          const float d = x[i].d;

          // Each quant is 4 low bits from ql plus 2 high bits from qh, offset by -32.
  #if K_QUANTS_PER_ITERATION == 1
          float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
                    + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
                    + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
                    + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
                    + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
                    + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
                    + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
                    +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
          tmp += sum;
  #else
          float sum = 0;
          for (int l = 0; l < 4; ++l) {
              sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
                   + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
                   + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0]  >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
                   + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
          }
          tmp += sum;
  #endif

      }

  #else

      const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION);  // 0...7
      const int ix  = threadIdx.x%(2*K_QUANTS_PER_ITERATION);  // 0...3

      const int step = tid * K_QUANTS_PER_ITERATION;

      float tmp = 0; // partial sum for thread in warp

      for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {

          const float   * y  = yy + i * QK_K + step;
          const uint8_t * ql = x[i].ql + step;
          const uint8_t * qh = x[i].qh + step;
          const int8_t  * s  = x[i].scales;

          const float d = x[i+0].d;

          float sum = 0;
          for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
              sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
                   + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
                   + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
                   + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >>  4) | ((qh[j] & 0xc0) >> 2)) - 32);
          }
          tmp += sum;

      }

  #endif

      // sum up partial sums and write back result
      tmp = warp_reduce_sum(tmp);

      // Gate on the lane index, as the q2/q3/q5 kernels do. `tid == 0` also matched lane 1 when
      // K_QUANTS_PER_ITERATION == 2, so two threads stored to dst[row].
      if (threadIdx.x == 0) {
          dst[row] = tmp;
      }
  }
  // "Dequantize" callback for plain f16 data: loads the two consecutive half values at
  // index ib + iqs and ib + iqs + 1 of vx into v.x and v.y.
  static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
      const half * src   = (const half *) vx;
      const int    first = ib + iqs;

      // the assignment converts half -> float automatically when dfloat == float
      v.x = src[first + 0];
      v.y = src[first + 1];
  }
  // Generic dequantize + matrix-vector product, parameterised on the quantization format.
  //
  //   qk                : quantized weights per x block
  //   qr                : number of quantized weights per data value in x block
  //   dequantize_kernel : callback producing two dequantized weights for (block index, quant index)
  //
  //   vx / y / dst      : quantized matrix, input vector (dfloat), float output with one entry per row
  //   ncols / nrows     : matrix dimensions; rows at or beyond nrows are skipped
  //
  // One threadIdx.y slice handles one row; threadIdx.x indexes the columns inside that slice.
  template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
  static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
      const int row = blockIdx.x*blockDim.y + threadIdx.y;

      if (row >= nrows) {
          return;
      }

      const int lane = threadIdx.x;

      const int cols_per_pass = 2*GGML_CUDA_DMMV_X;
      const int vals_per_lane = cols_per_pass / WARP_SIZE; // num quantized vals per thread and pass
      const int pair_offset   = qr == 1 ? 1 : qk/2;        // distance in y between the two values of a pair

      // per-thread accumulator
  #ifdef GGML_CUDA_F16
      half2 acc = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
  #else
      float acc = 0.0f;
  #endif // GGML_CUDA_F16

      for (int base = 0; base < ncols; base += cols_per_pass) {
          const int col  = base + vals_per_lane*lane;
          const int ib   = (row*ncols + col)/qk; // x block index
          const int iqs  = (col%qk)/qr;          // x quant index
          const int iybs = col - col%qk;         // y block start index

          // processing >2 values per pass is faster for fast GPUs
  #pragma unroll
          for (int j = 0; j < vals_per_lane; j += 2) {
              // two values per j step; for qr = 2 the quant index advances by 1 per step
              // because one data value carries 2 weights
              dfloat2 v;
              dequantize_kernel(vx, ib, iqs + j/qr, v);

              // y index of the first value of the pair; the second sits pair_offset further
              const int iy = iybs + iqs + j/qr;
  #ifdef GGML_CUDA_F16
              acc += __hmul2(v, {
                  y[iy + 0],
                  y[iy + pair_offset]
              });
  #else
              acc += v.x * y[iy + 0];
              acc += v.y * y[iy + pair_offset];
  #endif // GGML_CUDA_F16
          }
      }

      // combine the per-thread accumulators and let lane 0 store the row result
      acc = warp_reduce_sum(acc);

      if (lane == 0) {
  #ifdef GGML_CUDA_F16
          dst[row] = acc.x + acc.y;
  #else
          dst[row] = acc;
  #endif // GGML_CUDA_F16
      }
  }
  // Launches the generic dmmv kernel for Q4_0 weights on `stream`.
  // Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // Rows are spread over grid.x because their count may exceed the maximum grid size
      // of the y or z dimensions.
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic dmmv kernel for Q4_1 weights on `stream`.
  // Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // GGML_CUDA_MMV_Y rows per thread block, blocks laid out along grid.x
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic dmmv kernel for Q5_0 weights on `stream`.
  // Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // GGML_CUDA_MMV_Y rows per thread block, blocks laid out along grid.x
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic dmmv kernel for Q5_1 weights on `stream`.
  // Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // GGML_CUDA_MMV_Y rows per thread block, blocks laid out along grid.x
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic dmmv kernel for Q8_0 weights on `stream`.
  // Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // GGML_CUDA_MMV_Y rows per thread block, blocks laid out along grid.x
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q2_K dmmv kernel on `stream`. Requires ncols to be a multiple of QK_K.
  static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // rows per thread block; 2 is very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
      const int ny = 2;

      const int  n_blocks = (nrows + ny - 1) / ny;  // rounded up, the kernel skips surplus rows
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(32, ny, 1);

      dequantize_mul_mat_vec_q2_k<<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q3_K dmmv kernel on `stream`. Requires ncols to be a multiple of QK_K.
  static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // rows per thread block: 2 when K_QUANTS_PER_ITERATION == 1, otherwise 1
      const int ny = 2 / K_QUANTS_PER_ITERATION;

      const int  n_blocks = (nrows + ny - 1) / ny;  // rounded up, the kernel skips surplus rows
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(32, ny, 1);

      dequantize_mul_mat_vec_q3_k<<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q4_K dmmv kernel on `stream`. Requires ncols to be a multiple of QK_K.
  static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // rows per thread block: 2 when K_QUANTS_PER_ITERATION == 1, otherwise 1
      const int ny = 2 / K_QUANTS_PER_ITERATION;

      const int  n_blocks = (nrows + ny - 1) / ny;  // rounded up, the kernel skips surplus rows
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(32, ny, 1);

      dequantize_mul_mat_vec_q4_k<<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the Q5_K dmmv kernel on `stream` with one 32-thread block per row.
  // Requires ncols to be a multiple of QK_K.
  static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // the kernel takes its row from blockIdx.x and has no row guard, hence exactly nrows blocks
      const dim3 block(32, 1, 1);

      dequantize_mul_mat_vec_q5_k<<<nrows, block, 0, stream>>>(vx, y, dst, ncols);
  }
  // Launches the Q6_K dmmv kernel on `stream`. Requires ncols to be a multiple of QK_K.
  static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % QK_K == 0);

      // rows per thread block: 2 when K_QUANTS_PER_ITERATION == 1, otherwise 1
      const int ny = 2 / K_QUANTS_PER_ITERATION;

      const int  n_blocks = (nrows + ny - 1) / ny;  // rounded up, the kernel skips surplus rows
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(32, ny, 1);

      dequantize_mul_mat_vec_q6_k<<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Launches the generic dmmv kernel for unquantized f16 weights (qk = 1, qr = 1, convert_f16
  // as the "dequantize" callback) on `stream`. Requires ncols to be a multiple of GGML_CUDA_DMMV_X.
  static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
      GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);

      // GGML_CUDA_MMV_Y rows per thread block, blocks laid out along grid.x
      const int  n_blocks = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
      const dim3 grid(n_blocks, 1, 1);
      const dim3 block(WARP_SIZE, GGML_CUDA_MMV_Y, 1);

      dequantize_mul_mat_vec<1, 1, convert_f16>
          <<<grid, block, 0, stream>>>(vx, y, dst, ncols, nrows);
  }
  // Host entry point of the dmmv op: picks the launcher matching src0->type and runs it on
  // `stream` for the row slice [row_low, row_high).
  //
  //   src0_dd_i  : device data of src0 (the possibly quantized matrix) for this slice
  //   src1_ddf_i : device data of src1 as float; src1 must be of type F32
  //   dst_dd_i   : device output, one float per row of the slice
  //
  // With GGML_CUDA_F16 defined, src1 is first converted to half for the non-k-quant types
  // (on some GPUs half precision intrinsics are faster); the k-quant kernels always read floats.
  // dst, src1_ddq_i, src1_ncols and src1_padded_row_size are not used here.
  void ggml_cuda_op_dequantize_mul_mat_vec(
      ggml_backend_cuda_context & ctx,
      const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
      const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
      const int64_t src1_padded_row_size, cudaStream_t stream) {

      // parameters that are unused in some or all build configurations
      GGML_UNUSED(ctx);
      GGML_UNUSED(src1);
      GGML_UNUSED(dst);
      GGML_UNUSED(src1_ddq_i);
      GGML_UNUSED(src1_ncols);
      GGML_UNUSED(src1_padded_row_size);

      const int64_t n_cols = src0->ne[0];
      const int64_t n_rows = row_high - row_low;

      GGML_ASSERT(src1->type == GGML_TYPE_F32);

  #ifdef GGML_CUDA_F16
      // dfloat == half: convert src1 once for the kernels that consume it as half
      ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
      half * src1_dfloat = nullptr;

      bool wants_half = false;
      switch (src0->type) {
          case GGML_TYPE_Q4_0:
          case GGML_TYPE_Q4_1:
          case GGML_TYPE_Q5_0:
          case GGML_TYPE_Q5_1:
          case GGML_TYPE_Q8_0:
          case GGML_TYPE_F16:
              wants_half = true;
              break;
          default:
              break;
      }

      if (wants_half) {
          src1_dfloat = src1_dfloat_a.alloc(n_cols);
          const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
          GGML_ASSERT(to_fp16_cuda != nullptr);
          to_fp16_cuda(src1_ddf_i, src1_dfloat, n_cols, stream);
      }
  #else
      // dfloat == float, no conversion
      const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i;
  #endif // GGML_CUDA_F16

      switch (src0->type) {
          // generic kernel: consumes src1 as dfloat
          case GGML_TYPE_Q4_0: dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q4_1: dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q5_0: dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q5_1: dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q8_0: dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_F16:  convert_mul_mat_vec_f16_cuda    (src0_dd_i, src1_dfloat, dst_dd_i, n_cols, n_rows, stream); break;

          // k-quant kernels: always consume the float data of src1
          case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q3_K: dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q4_K: dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q5_K: dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream); break;
          case GGML_TYPE_Q6_K: dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, n_cols, n_rows, stream); break;

          default:
              GGML_ASSERT(false);
              break;
      }
  }