convert.cu 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686
  1. #include "convert.cuh"
  2. #include "dequantize.cuh"
  3. #define CUDA_Q8_0_NE_ALIGN 2048
  4. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  5. static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
  6. const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x);
  7. if (i >= k) {
  8. return;
  9. }
  10. const int64_t ib = i/qk; // block index
  11. const int64_t iqs = (i%qk)/qr; // quant index
  12. const int64_t iybs = i - i%qk; // y block start index
  13. const int64_t y_offset = qr == 1 ? 1 : qk/2;
  14. // dequantize
  15. dfloat2 v;
  16. dequantize_kernel(vx, ib, iqs, v);
  17. y[iybs + iqs + 0] = v.x;
  18. y[iybs + iqs + y_offset] = v.y;
  19. }
  20. template <bool need_check>
  21. static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
  22. #if __CUDA_ARCH__ >= CC_PASCAL
  23. constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
  24. const int64_t i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
  25. const int * x0 = ((int *) vx) + blockIdx.x * nint;
  26. half2 * y2 = (half2 *) (y + i0);
  27. __shared__ int vals[nint];
  28. #pragma unroll
  29. for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
  30. if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
  31. break;
  32. }
  33. const int ix = ix0 + threadIdx.x;
  34. vals[ix] = x0[ix];
  35. }
  36. __syncthreads();
  37. #pragma unroll
  38. for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
  39. if (need_check && i0 + iy + 2*threadIdx.x >= k) {
  40. return;
  41. }
  42. const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
  43. const half d = *b0;
  44. const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
  45. y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
  46. }
  47. #else
  48. GGML_UNUSED(vx);
  49. GGML_UNUSED(y);
  50. GGML_UNUSED(k);
  51. NO_DEVICE_CODE;
  52. #endif // __CUDA_ARCH__ >= CC_PASCAL
  53. }
  54. template<typename dst_t>
  55. static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
  56. const int64_t i = blockIdx.x;
  57. // assume 32 threads
  58. const int64_t tid = threadIdx.x;
  59. const int64_t il = tid/8;
  60. const int64_t ir = tid%8;
  61. const int64_t ib = 8*i + ir;
  62. if (ib >= nb32) {
  63. return;
  64. }
  65. dst_t * y = yy + 256*i + 32*ir + 4*il;
  66. const block_q4_0 * x = (const block_q4_0 *)vx + ib;
  67. const float d = __half2float(x->d);
  68. const float dm = -8*d;
  69. const uint8_t * q = x->qs + 4*il;
  70. for (int l = 0; l < 4; ++l) {
  71. y[l+ 0] = d * (q[l] & 0xF) + dm;
  72. y[l+16] = d * (q[l] >> 4) + dm;
  73. }
  74. }
  75. template<typename dst_t>
  76. static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
  77. const int64_t i = blockIdx.x;
  78. // assume 32 threads
  79. const int64_t tid = threadIdx.x;
  80. const int64_t il = tid/8;
  81. const int64_t ir = tid%8;
  82. const int64_t ib = 8*i + ir;
  83. if (ib >= nb32) {
  84. return;
  85. }
  86. dst_t * y = yy + 256*i + 32*ir + 4*il;
  87. const block_q4_1 * x = (const block_q4_1 *)vx + ib;
  88. const float2 d = __half22float2(x->dm);
  89. const uint8_t * q = x->qs + 4*il;
  90. for (int l = 0; l < 4; ++l) {
  91. y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
  92. y[l+16] = d.x * (q[l] >> 4) + d.y;
  93. }
  94. }
  95. //================================== k-quants
  96. template<typename dst_t>
  97. static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  98. const int64_t i = blockIdx.x;
  99. const block_q2_K * x = (const block_q2_K *) vx;
  100. const int64_t tid = threadIdx.x;
  101. const int64_t n = tid/32;
  102. const int64_t l = tid - 32*n;
  103. const int64_t is = 8*n + l/16;
  104. const uint8_t q = x[i].qs[32*n + l];
  105. dst_t * y = yy + i*QK_K + 128*n;
  106. float dall = __low2half(x[i].dm);
  107. float dmin = __high2half(x[i].dm);
  108. y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  109. y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
  110. y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
  111. y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
  112. }
  113. template<typename dst_t>
  114. static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  115. const int64_t i = blockIdx.x;
  116. const block_q3_K * x = (const block_q3_K *) vx;
  117. const int64_t r = threadIdx.x/4;
  118. const int64_t tid = r/2;
  119. const int64_t is0 = r%2;
  120. const int64_t l0 = 16*is0 + 4*(threadIdx.x%4);
  121. const int64_t n = tid / 4;
  122. const int64_t j = tid - 4*n;
  123. uint8_t m = 1 << (4*n + j);
  124. int64_t is = 8*n + 2*j + is0;
  125. int shift = 2*j;
  126. int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
  127. is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
  128. is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
  129. (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
  130. float d_all = x[i].d;
  131. float dl = d_all * (us - 32);
  132. dst_t * y = yy + i*QK_K + 128*n + 32*j;
  133. const uint8_t * q = x[i].qs + 32*n;
  134. const uint8_t * hm = x[i].hmask;
  135. for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
  136. }
  137. static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
  138. if (j < 4) {
  139. d = q[j] & 63; m = q[j + 4] & 63;
  140. } else {
  141. d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
  142. m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  143. }
  144. }
  145. template<typename dst_t>
  146. static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  147. const block_q4_K * x = (const block_q4_K *) vx;
  148. const int64_t i = blockIdx.x;
  149. // assume 32 threads
  150. const int64_t tid = threadIdx.x;
  151. const int64_t il = tid/8;
  152. const int64_t ir = tid%8;
  153. const int64_t is = 2*il;
  154. const int64_t n = 4;
  155. dst_t * y = yy + i*QK_K + 64*il + n*ir;
  156. const float dall = __low2half(x[i].dm);
  157. const float dmin = __high2half(x[i].dm);
  158. const uint8_t * q = x[i].qs + 32*il + n*ir;
  159. uint8_t sc, m;
  160. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  161. const float d1 = dall * sc; const float m1 = dmin * m;
  162. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  163. const float d2 = dall * sc; const float m2 = dmin * m;
  164. for (int l = 0; l < n; ++l) {
  165. y[l + 0] = d1 * (q[l] & 0xF) - m1;
  166. y[l +32] = d2 * (q[l] >> 4) - m2;
  167. }
  168. }
  169. template<typename dst_t>
  170. static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  171. const block_q5_K * x = (const block_q5_K *) vx;
  172. const int64_t i = blockIdx.x;
  173. // assume 64 threads - this is very slightly better than the one below
  174. const int64_t tid = threadIdx.x;
  175. const int64_t il = tid/16; // il is in 0...3
  176. const int64_t ir = tid%16; // ir is in 0...15
  177. const int64_t is = 2*il; // is is in 0...6
  178. dst_t * y = yy + i*QK_K + 64*il + 2*ir;
  179. const float dall = __low2half(x[i].dm);
  180. const float dmin = __high2half(x[i].dm);
  181. const uint8_t * ql = x[i].qs + 32*il + 2*ir;
  182. const uint8_t * qh = x[i].qh + 2*ir;
  183. uint8_t sc, m;
  184. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  185. const float d1 = dall * sc; const float m1 = dmin * m;
  186. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  187. const float d2 = dall * sc; const float m2 = dmin * m;
  188. uint8_t hm = 1 << (2*il);
  189. y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
  190. y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
  191. hm <<= 1;
  192. y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
  193. y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
  194. }
  195. template<typename dst_t>
  196. static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  197. const block_q6_K * x = (const block_q6_K *) vx;
  198. const int64_t i = blockIdx.x;
  199. // assume 64 threads - this is very slightly better than the one below
  200. const int64_t tid = threadIdx.x;
  201. const int64_t ip = tid/32; // ip is 0 or 1
  202. const int64_t il = tid - 32*ip; // 0...32
  203. const int64_t is = 8*ip + il/16;
  204. dst_t * y = yy + i*QK_K + 128*ip + il;
  205. const float d = x[i].d;
  206. const uint8_t * ql = x[i].ql + 64*ip + il;
  207. const uint8_t qh = x[i].qh[32*ip + il];
  208. const int8_t * sc = x[i].scales + is;
  209. y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  210. y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
  211. y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  212. y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
  213. }
  214. template<typename dst_t>
  215. static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  216. const int64_t i = blockIdx.x;
  217. const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
  218. const int64_t tid = threadIdx.x;
  219. const int64_t il = tid/8; // 0...3
  220. const int64_t ib = tid%8; // 0...7
  221. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  222. const uint16_t * q2 = x[i].qs + 4*ib;
  223. const uint8_t * aux8 = (const uint8_t *)q2;
  224. const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
  225. const uint32_t aux32 = q2[2] | (q2[3] << 16);
  226. const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
  227. const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
  228. for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  229. }
  230. template<typename dst_t>
  231. static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  232. const int64_t i = blockIdx.x;
  233. const block_iq2_xs * x = (const block_iq2_xs *) vx;
  234. const int64_t tid = threadIdx.x;
  235. const int64_t il = tid/8; // 0...3
  236. const int64_t ib = tid%8; // 0...7
  237. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  238. const uint16_t * q2 = x[i].qs + 4*ib;
  239. const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
  240. const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  241. const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
  242. for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  243. }
  244. template<typename dst_t>
  245. static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  246. const int64_t i = blockIdx.x;
  247. const block_iq2_s * x = (const block_iq2_s *) vx;
  248. const int64_t tid = threadIdx.x;
  249. const int64_t il = tid/8; // 0...3
  250. const int64_t ib = tid%8; // 0...7
  251. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  252. const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
  253. const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  254. const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
  255. for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  256. }
  257. template<typename dst_t>
  258. static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  259. const int64_t i = blockIdx.x;
  260. const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
  261. const int64_t tid = threadIdx.x;
  262. const int64_t il = tid/8; // 0...3
  263. const int64_t ib = tid%8; // 0...7
  264. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  265. const uint8_t * q3 = x[i].qs + 8*ib;
  266. const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
  267. const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
  268. const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
  269. const uint32_t aux32 = gas[0] | (gas[1] << 16);
  270. const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
  271. const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
  272. for (int j = 0; j < 4; ++j) {
  273. y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
  274. y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
  275. }
  276. }
  277. template<typename dst_t>
  278. static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  279. const int64_t i = blockIdx.x;
  280. const block_iq3_s * x = (const block_iq3_s *) vx;
  281. const int64_t tid = threadIdx.x;
  282. const int64_t il = tid/8; // 0...3
  283. const int64_t ib = tid%8; // 0...7
  284. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  285. const uint8_t * qs = x[i].qs + 8*ib;
  286. const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
  287. const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
  288. const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
  289. const uint8_t signs = x[i].signs[4*ib + il];
  290. for (int j = 0; j < 4; ++j) {
  291. y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
  292. y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
  293. }
  294. }
  295. template<typename dst_t>
  296. static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  297. const int64_t i = blockIdx.x;
  298. const block_iq1_s * x = (const block_iq1_s *) vx;
  299. const int64_t tid = threadIdx.x;
  300. const int64_t il = tid/8; // 0...3
  301. const int64_t ib = tid%8; // 0...7
  302. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  303. const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
  304. const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
  305. uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
  306. grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
  307. grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  308. grid32[0] &= 0x0f0f0f0f;
  309. for (int j = 0; j < 8; ++j) {
  310. y[j] = d * (q[j] + delta);
  311. }
  312. }
  313. template<typename dst_t>
  314. static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  315. const int64_t i = blockIdx.x;
  316. const block_iq1_m * x = (const block_iq1_m *) vx;
  317. const int64_t tid = threadIdx.x;
  318. const int64_t il = tid/8; // 0...3
  319. const int64_t ib = tid%8; // 0...7
  320. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  321. const uint16_t * sc = (const uint16_t *)x[i].scales;
  322. iq1m_scale_t scale;
  323. scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
  324. const int64_t ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
  325. const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
  326. const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
  327. uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
  328. grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
  329. grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  330. grid32[0] &= 0x0f0f0f0f;
  331. for (int j = 0; j < 8; ++j) {
  332. y[j] = d * (q[j] + delta);
  333. }
  334. }
  335. template<typename dst_t>
  336. static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  337. const int64_t i = blockIdx.x;
  338. const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
  339. const int64_t tid = threadIdx.x;
  340. const int64_t il = tid/8; // 0...3
  341. const int64_t ib = tid%8; // 0...7
  342. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  343. const uint8_t * q4 = x[ib].qs + 4*il;
  344. const float d = (float)x[ib].d;
  345. for (int j = 0; j < 4; ++j) {
  346. y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
  347. y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
  348. }
  349. }
  350. template<typename dst_t>
  351. static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
  352. const int64_t i = blockIdx.x;
  353. const block_iq4_xs * x = (const block_iq4_xs *)vx;
  354. const int64_t tid = threadIdx.x;
  355. const int64_t il = tid/8; // 0...3
  356. const int64_t ib = tid%8; // 0...7
  357. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  358. const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
  359. const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
  360. for (int j = 0; j < 4; ++j) {
  361. y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
  362. y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
  363. }
  364. }
  365. template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
  366. static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
  367. const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
  368. dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  369. }
  370. static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
  371. const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
  372. if (k % CUDA_Q8_0_NE_ALIGN == 0) {
  373. const bool need_check = false;
  374. dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
  375. } else {
  376. const bool need_check = true;
  377. dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
  378. }
  379. }
  380. template<typename dst_t>
  381. static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  382. const int nb = k / QK_K;
  383. dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
  384. }
  385. template<typename dst_t>
  386. static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  387. const int nb = k / QK_K;
  388. dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
  389. }
  390. template<typename dst_t>
  391. static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  392. const int nb32 = k / 32;
  393. const int nb = (k + 255) / 256;
  394. dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
  395. }
  396. template<typename dst_t>
  397. static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  398. const int nb32 = k / 32;
  399. const int nb = (k + 255) / 256;
  400. dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
  401. }
  402. template<typename dst_t>
  403. static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  404. const int nb = k / QK_K;
  405. dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
  406. }
  407. template<typename dst_t>
  408. static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  409. const int nb = k / QK_K;
  410. dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
  411. }
  412. template<typename dst_t>
  413. static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  414. const int nb = k / QK_K;
  415. dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
  416. }
  417. template<typename dst_t>
  418. static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  419. const int nb = k / QK_K;
  420. dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
  421. }
  422. template<typename dst_t>
  423. static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  424. const int nb = k / QK_K;
  425. dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
  426. }
  427. template<typename dst_t>
  428. static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  429. const int nb = k / QK_K;
  430. dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
  431. }
  432. template<typename dst_t>
  433. static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  434. const int nb = k / QK_K;
  435. dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
  436. }
  437. template<typename dst_t>
  438. static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  439. const int nb = k / QK_K;
  440. dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
  441. }
  442. template<typename dst_t>
  443. static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  444. const int nb = k / QK_K;
  445. dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
  446. }
  447. template<typename dst_t>
  448. static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  449. const int nb = (k + QK_K - 1) / QK_K;
  450. dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
  451. }
  452. template<typename dst_t>
  453. static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  454. const int nb = k / QK_K;
  455. dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
  456. }
  457. template<typename dst_t>
  458. static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
  459. const int nb = (k + QK_K - 1) / QK_K;
  460. dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
  461. }
  462. template <typename src_t, typename dst_t>
  463. static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
  464. const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
  465. if (i >= k) {
  466. return;
  467. }
  468. const src_t * x = (src_t *) vx;
  469. y[i] = x[i];
  470. }
  471. template <typename src_t, typename dst_t>
  472. static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
  473. const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
  474. convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
  475. }
  476. to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
  477. switch (type) {
  478. case GGML_TYPE_Q4_0:
  479. return dequantize_row_q4_0_cuda;
  480. case GGML_TYPE_Q4_1:
  481. return dequantize_row_q4_1_cuda;
  482. case GGML_TYPE_Q5_0:
  483. return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
  484. case GGML_TYPE_Q5_1:
  485. return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
  486. case GGML_TYPE_Q8_0:
  487. if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
  488. return dequantize_block_q8_0_f16_cuda;
  489. }
  490. return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
  491. case GGML_TYPE_Q2_K:
  492. return dequantize_row_q2_K_cuda;
  493. case GGML_TYPE_Q3_K:
  494. return dequantize_row_q3_K_cuda;
  495. case GGML_TYPE_Q4_K:
  496. return dequantize_row_q4_K_cuda;
  497. case GGML_TYPE_Q5_K:
  498. return dequantize_row_q5_K_cuda;
  499. case GGML_TYPE_Q6_K:
  500. return dequantize_row_q6_K_cuda;
  501. case GGML_TYPE_IQ2_XXS:
  502. return dequantize_row_iq2_xxs_cuda;
  503. case GGML_TYPE_IQ2_XS:
  504. return dequantize_row_iq2_xs_cuda;
  505. case GGML_TYPE_IQ2_S:
  506. return dequantize_row_iq2_s_cuda;
  507. case GGML_TYPE_IQ3_XXS:
  508. return dequantize_row_iq3_xxs_cuda;
  509. case GGML_TYPE_IQ1_S:
  510. return dequantize_row_iq1_s_cuda;
  511. case GGML_TYPE_IQ1_M:
  512. return dequantize_row_iq1_m_cuda;
  513. case GGML_TYPE_IQ4_NL:
  514. return dequantize_row_iq4_nl_cuda;
  515. case GGML_TYPE_IQ4_XS:
  516. return dequantize_row_iq4_xs_cuda;
  517. case GGML_TYPE_IQ3_S:
  518. return dequantize_row_iq3_s_cuda;
  519. case GGML_TYPE_F32:
  520. return convert_unary_cuda<float>;
  521. default:
  522. return nullptr;
  523. }
  524. }
  525. to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
  526. switch (type) {
  527. case GGML_TYPE_Q4_0:
  528. return dequantize_row_q4_0_cuda;
  529. case GGML_TYPE_Q4_1:
  530. return dequantize_row_q4_1_cuda;
  531. case GGML_TYPE_Q5_0:
  532. return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
  533. case GGML_TYPE_Q5_1:
  534. return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
  535. case GGML_TYPE_Q8_0:
  536. return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
  537. case GGML_TYPE_Q2_K:
  538. return dequantize_row_q2_K_cuda;
  539. case GGML_TYPE_Q3_K:
  540. return dequantize_row_q3_K_cuda;
  541. case GGML_TYPE_Q4_K:
  542. return dequantize_row_q4_K_cuda;
  543. case GGML_TYPE_Q5_K:
  544. return dequantize_row_q5_K_cuda;
  545. case GGML_TYPE_Q6_K:
  546. return dequantize_row_q6_K_cuda;
  547. case GGML_TYPE_IQ2_XXS:
  548. return dequantize_row_iq2_xxs_cuda;
  549. case GGML_TYPE_IQ2_XS:
  550. return dequantize_row_iq2_xs_cuda;
  551. case GGML_TYPE_IQ2_S:
  552. return dequantize_row_iq2_s_cuda;
  553. case GGML_TYPE_IQ3_XXS:
  554. return dequantize_row_iq3_xxs_cuda;
  555. case GGML_TYPE_IQ1_S:
  556. return dequantize_row_iq1_s_cuda;
  557. case GGML_TYPE_IQ1_M:
  558. return dequantize_row_iq1_m_cuda;
  559. case GGML_TYPE_IQ4_NL:
  560. return dequantize_row_iq4_nl_cuda;
  561. case GGML_TYPE_IQ4_XS:
  562. return dequantize_row_iq4_xs_cuda;
  563. case GGML_TYPE_IQ3_S:
  564. return dequantize_row_iq3_s_cuda;
  565. case GGML_TYPE_F16:
  566. return convert_unary_cuda<half>;
  567. default:
  568. return nullptr;
  569. }
  570. }