// convert.cu
#include "convert.cuh"
#include "dequantize.cuh"

#define CUDA_Q8_0_NE_ALIGN 2048
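
// Generic dequantization kernel: each thread produces two output values via
// the type-specific dequantize_kernel. qk is the quantization block size and
// qr the quantization ratio; for qr == 1 (q8_0) the pair of outputs is
// adjacent, otherwise it spans the low-/high-nibble halves of the block
// (y_offset = qk/2).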
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x); // cast guards against 32-bit overflow for large k

    if (i >= k) {
        return;
    }

    const int64_t ib = i/qk;        // block index
    const int iqs = (i%qk)/qr;      // quant index
    const int iybs = i - i%qk;      // y block start index
    const int y_offset = qr == 1 ? 1 : qk/2;

    // dequantize
    dfloat2 v;
    dequantize_kernel(vx, ib, iqs, v);

    y[iybs + iqs + 0]        = v.x;
    y[iybs + iqs + y_offset] = v.y;
}
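
// Fast q8_0 -> f16 path: each block (one warp) stages CUDA_Q8_0_NE_ALIGN
// values worth of raw block_q8_0 data into shared memory as ints, then
// dequantizes two values per thread per iteration with half2 arithmetic,
// which is why this kernel is limited to CC_PASCAL and newer.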
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
#if __CUDA_ARCH__ >= CC_PASCAL
    constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;

    const int   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
    const int * x0 = ((int *) vx) + blockIdx.x * nint;
    half2     * y2 = (half2 *) (y + i0);

    __shared__ int vals[nint];

#pragma unroll
    for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
        if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
            break;
        }

        const int ix = ix0 + threadIdx.x;
        vals[ix] = x0[ix];
    }

#pragma unroll
    for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
        if (need_check && i0 + iy + 2*threadIdx.x >= k) {
            return;
        }

        const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
        const half    d = *b0;
        const char2  qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];

        y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
    }
#else
    GGML_UNUSED(vx);
    GGML_UNUSED(y);
    GGML_UNUSED(k);
    NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
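
// q4_0/q4_1: one 32-thread block dequantizes 256 values (8 quant blocks of
// 32). Lane mapping: ir = tid%8 picks the quant block, il = tid/8 picks a
// 4-byte group of its 16 quant bytes; low nibbles fill y[0..15] of the block,
// high nibbles y[16..31]. For q4_0 the result is d*q - 8*d (q in 0..15);
// q4_1 uses d*q + m instead.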
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int tid = threadIdx.x;
    const int il  = tid/8;
    const int ir  = tid%8;
    const int64_t ib = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_0 * x = (const block_q4_0 *)vx + ib;
    const float d  = __half2float(x->d);
    const float dm = -8*d;

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = d * (q[l] & 0xF) + dm;
        y[l+16] = d * (q[l] >>  4) + dm;
    }
}
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
    const int64_t i = blockIdx.x;

    // assume 32 threads
    const int tid = threadIdx.x;
    const int il  = tid/8;
    const int ir  = tid%8;
    const int64_t ib = 8*i + ir;
    if (ib >= nb32) {
        return;
    }

    dst_t * y = yy + 256*i + 32*ir + 4*il;

    const block_q4_1 * x = (const block_q4_1 *)vx + ib;
    const float2 d = __half22float2(x->dm);

    const uint8_t * q = x->qs + 4*il;

    for (int l = 0; l < 4; ++l) {
        y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
        y[l+16] = d.x * (q[l] >>  4) + d.y;
    }
}
//================================== k-quants
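
// Each k-quant kernel below decodes one super-block of QK_K values (256 by
// default, 64 as a build-time option) per CUDA block; scales/mins are packed
// per sub-block and widened by the super-block's fp16 scale(s).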
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_q2_K * x = (const block_q2_K *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int n  = tid/32;
    const int l  = tid - 32*n;
    const int is = 8*n + l/16;

    const uint8_t q = x[i].qs[32*n + l];
    dst_t * y = yy + i*QK_K + 128*n;

    float dall = __low2half(x[i].dm);
    float dmin = __high2half(x[i].dm);
    y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
    y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
    y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
    const int is = tid/16;  // 0 or 1
    const int il = tid%16;  // 0...15
    const uint8_t q = x[i].qs[il] >> (2*is);
    dst_t * y = yy + i*QK_K + 16*is + il;
    float dall = __low2half(x[i].dm);
    float dmin = __high2half(x[i].dm);
    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
}
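
// q3_K packs 16 6-bit sub-scales into 12 bytes: the low 4 bits live as
// nibbles in scales[0..7] and the high 2 bits in scales[8..11]; the
// conditional chain below reassembles the 6-bit value before centering it
// with -32. hmask holds the high (third) bit of each 2-bit quant: when it
// is clear, 4 is subtracted.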
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_q3_K * x = (const block_q3_K *) vx;

#if QK_K == 256
    const int r   = threadIdx.x/4;
    const int tid = r/2;
    const int is0 = r%2;
    const int l0  = 16*is0 + 4*(threadIdx.x%4);
    const int n   = tid / 4;
    const int j   = tid - 4*n;

    uint8_t m = 1 << (4*n + j);
    int is = 8*n + 2*j + is0;
    int shift = 2*j;

    int8_t us = is <  4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
                is <  8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
                is < 12 ? (x[i].scales[is-8] >>  4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
                          (x[i].scales[is-8] >>  4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
    float d_all = x[i].d;
    float dl = d_all * (us - 32);

    dst_t * y = yy + i*QK_K + 128*n + 32*j;
    const uint8_t * q  = x[i].qs + 32*n;
    const uint8_t * hm = x[i].hmask;

    for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
    const int tid = threadIdx.x;
    const int is  = tid/16;  // 0 or 1
    const int il  = tid%16;  // 0...15
    const int im  = il/8;    // 0...1
    const int in  = il%8;    // 0...7

    dst_t * y = yy + i*QK_K + 16*is + il;

    const uint8_t q = x[i].qs[il] >> (2*is);
    const uint8_t h = x[i].hmask[in] >> (2*is + im);
    const float   d = (float)x[i].d;

    if (is == 0) {
        y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    } else {
        y[ 0] = d * ((x[i].scales[0] >>  4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
        y[32] = d * ((x[i].scales[1] >>  4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
    }
#endif
}
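
// q4_K/q5_K pack 8 6-bit scales and 8 6-bit mins into 12 bytes: entries
// 0..3 occupy the low 6 bits of bytes 0..7, while entries 4..7 are split
// into a low nibble (bytes 8..11) plus 2 high bits borrowed from the top
// of bytes 0..7.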
#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
    if (j < 4) {
        d = q[j] & 63; m = q[j + 4] & 63;
    } else {
        d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        m = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
    }
}
#endif
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q4_K * x = (const block_q4_K *) vx;

    const int i = blockIdx.x;

#if QK_K == 256
    // assume 32 threads
    const int tid = threadIdx.x;
    const int il  = tid/8;
    const int ir  = tid%8;
    const int is  = 2*il;
    const int n   = 4;

    dst_t * y = yy + i*QK_K + 64*il + n*ir;

    const float dall = __low2half(x[i].dm);
    const float dmin = __high2half(x[i].dm);

    const uint8_t * q = x[i].qs + 32*il + n*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;
    for (int l = 0; l < n; ++l) {
        y[l + 0] = d1 * (q[l] & 0xF) - m1;
        y[l +32] = d2 * (q[l] >>  4) - m2;
    }
#else
    const int tid = threadIdx.x;
    const uint8_t * q = x[i].qs;
    dst_t * y = yy + i*QK_K;
    const float d = (float)x[i].dm[0];
    const float m = (float)x[i].dm[1];
    y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
    y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
#endif
}
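
// q5_K: the qs nibbles hold the low 4 bits of each 5-bit quant; one bit of
// qh, selected by the hm mask (shifted up for the high-nibble half), supplies
// the fifth bit as a +16 offset.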
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q5_K * x = (const block_q5_K *) vx;

    const int i = blockIdx.x;

#if QK_K == 256
    // assume 64 threads - this is very slightly better than the one below
    const int tid = threadIdx.x;
    const int il  = tid/16;   // il is in 0...3
    const int ir  = tid%16;   // ir is in 0...15
    const int is  = 2*il;     // is is in 0...6

    dst_t * y = yy + i*QK_K + 64*il + 2*ir;

    const float dall = __low2half(x[i].dm);
    const float dmin = __high2half(x[i].dm);

    const uint8_t * ql = x[i].qs + 32*il + 2*ir;
    const uint8_t * qh = x[i].qh + 2*ir;

    uint8_t sc, m;
    get_scale_min_k4(is + 0, x[i].scales, sc, m);
    const float d1 = dall * sc; const float m1 = dmin * m;
    get_scale_min_k4(is + 1, x[i].scales, sc, m);
    const float d2 = dall * sc; const float m2 = dmin * m;

    uint8_t hm = 1 << (2*il);
    y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
    y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
    hm <<= 1;
    y[32] = d2 * ((ql[ 0] >>  4) + (qh[ 0] & hm ? 16 : 0)) - m2;
    y[33] = d2 * ((ql[ 1] >>  4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
    const int tid = threadIdx.x;
    const uint8_t q = x[i].qs[tid];
    const int im = tid/8;  // 0...3
    const int in = tid%8;  // 0...7
    const int is = tid/16; // 0 or 1
    const uint8_t h = x[i].qh[in] >> im;
    const float   d = x[i].d;
    dst_t * y = yy + i*QK_K + tid;
    y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
    y[32] = d * x[i].scales[is+2] * ((q >>  4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const block_q6_K * x = (const block_q6_K *) vx;

    const int64_t i = blockIdx.x;
#if QK_K == 256

    // assume 64 threads - this is very slightly better than the one below
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/32;      // ip is 0 or 1
    const int64_t il  = tid - 32*ip; // 0...31
    const int64_t is  = 8*ip + il/16;

    dst_t * y = yy + i*QK_K + 128*ip + il;

    const float d = x[i].d;

    const uint8_t * ql = x[i].ql + 64*ip + il;
    const uint8_t   qh = x[i].qh[32*ip + il];
    const int8_t  * sc = x[i].scales + is;

    y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
    y[64] = d * sc[4] * ((int8_t)((ql[ 0] >>  4) | (((qh >> 4) & 3) << 4)) - 32);
    y[96] = d * sc[6] * ((int8_t)((ql[32] >>  4) | (((qh >> 6) & 3) << 4)) - 32);
#else

    // assume 32 threads
    const int64_t tid = threadIdx.x;
    const int64_t ip  = tid/16;      // 0 or 1
    const int64_t il  = tid - 16*ip; // 0...15

    dst_t * y = yy + i*QK_K + 16*ip + il;

    const float d = x[i].d;

    const uint8_t  ql = x[i].ql[16*ip + il];
    const uint8_t  qh = x[i].qh[il] >> (2*ip);
    const int8_t * sc = x[i].scales;

    y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
    y[32] = d * sc[ip+2] * ((int8_t)((ql >>  4) | (((qh >> 4) & 3) << 4)) - 32);
#endif
}
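
// iq2_xxs: quants are indices into a shared codebook. Each byte of q2
// selects an 8-value row of iq2xxs_grid; aux32 packs four 7-bit sign
// patterns (looked up in ksigns_iq2xs) in its low 28 bits and a 4-bit scale
// in the top 4, applied as d * (0.5 + scale) * 0.25.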
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_xxs * x = (const block_iq2_xxs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * aux8 = (const uint8_t *)q2;
    const uint8_t  * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
    const uint32_t aux32 = q2[2] | (q2[3] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}
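
// iq2_xs widens the grid index to 9 bits (q2[il] & 511) and keeps a 7-bit
// sign-pattern index in the top bits of each 16-bit quant (q2[il] >> 9);
// sub-scales come from the nibbles of scales[ib]. iq2_s instead stores
// explicit sign bytes and spreads the grid index across qs and qh.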
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_xs * x = (const block_iq2_xs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * q2 = x[i].qs + 4*ib;
    const uint8_t  * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq2_s * x = (const block_iq2_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
    const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
    const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
    for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
    NO_DEVICE_CODE;
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq3_xxs * x = (const block_iq3_xxs *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t  * q3 = x[i].qs + 8*ib;
    const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
    const uint8_t  * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
    const uint8_t  * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
    const uint32_t aux32 = gas[0] | (gas[1] << 16);
    const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
    const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq3_s * x = (const block_iq3_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint8_t * qs = x[i].qs + 8*ib;
    const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
    const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
    const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
    const uint8_t signs = x[i].signs[4*ib + il];
    for (int j = 0; j < 4; ++j) {
        y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
        y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
    }
#else
    NO_DEVICE_CODE;
#endif
}
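
// iq1_s/iq1_m: an 8-bit index from qs plus 3 high bits from qh select a
// word of iq1s_grid_gpu, which is unpacked below into eight 4-bit values;
// delta shifts them to roughly -1 +/- IQ1{S,M}_DELTA and a 3-bit sub-scale
// widens d. iq1_m additionally reassembles its per-block fp16 scale from
// the top nibbles of the four 16-bit words of x[i].scales.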
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq1_s * x = (const block_iq1_s *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
    const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq1_m * x = (const block_iq1_m *) vx;

    const int tid = threadIdx.x;
#if QK_K == 256
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 8*il;
    const uint16_t * sc = (const uint16_t *)x[i].scales;
    iq1m_scale_t scale;
    scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
    const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
    const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
    const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
    uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
    grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
    grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
    grid32[0] &= 0x0f0f0f0f;
    for (int j = 0; j < 8; ++j) {
        y[j] = d * (q[j] + delta);
    }
#else
    NO_DEVICE_CODE;
#endif
}
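
// iq4_nl/iq4_xs: the 4-bit quants index the non-uniform value table
// kvalues_iq4nl rather than being used as linear levels, so dequantization
// is a table lookup scaled by d; iq4_xs additionally carries 6-bit sub-block
// scales split across scales_l/scales_h.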
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {

    const int i = blockIdx.x;
    const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);

    const int tid = threadIdx.x;
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[ib].qs + 4*il;
    const float d = (float)x[ib].d;
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
#if QK_K != 64
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
    const int i = blockIdx.x;
    const block_iq4_xs * x = (const block_iq4_xs *)vx;

    const int tid = threadIdx.x;
    const int il = tid/8; // 0...3
    const int ib = tid%8; // 0...7
    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
    const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
    const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
    for (int j = 0; j < 4; ++j) {
        y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
        y[j+16] = d * kvalues_iq4nl[q4[j] >>  4];
    }
}
#endif
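
// Launch helpers. dequantize_block writes two values per thread, so each
// CUDA block of CUDA_DEQUANTIZE_BLOCK_SIZE threads covers twice that many
// elements of k.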
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
    dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
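
// need_check is resolved at compile time: when k is a multiple of
// CUDA_Q8_0_NE_ALIGN the bounds checks in the kernel compile away entirely.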
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
    if (k % CUDA_Q8_0_NE_ALIGN == 0) {
        const bool need_check = false;
        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    } else {
        const bool need_check = true;
        dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
    }
}
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb32 = k / 32;
    const int nb = (k + 255) / 256;
    dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}

template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
#if QK_K == 256
    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
#else
    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}

template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = k / QK_K;
    dequantize_block_iq1_m<<<nb, 32, 0, stream>>>(vx, y);
}

template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
    const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
    dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
#else
    dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i >= k) {
        return;
    }

    const src_t * x = (src_t *) vx;

    y[i] = x[i];
}

template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
    int id;
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            CUDA_CHECK(cudaGetDevice(&id));
            if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
                return dequantize_block_q8_0_f16_cuda;
            }
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:
            return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F32:
            return convert_unary_cuda<float>;
        default:
            return nullptr;
    }
}
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
    switch (type) {
        case GGML_TYPE_Q4_0:
            return dequantize_row_q4_0_cuda;
        case GGML_TYPE_Q4_1:
            return dequantize_row_q4_1_cuda;
        case GGML_TYPE_Q5_0:
            return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
        case GGML_TYPE_Q5_1:
            return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
        case GGML_TYPE_Q8_0:
            return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
        case GGML_TYPE_Q2_K:
            return dequantize_row_q2_K_cuda;
        case GGML_TYPE_Q3_K:
            return dequantize_row_q3_K_cuda;
        case GGML_TYPE_Q4_K:
            return dequantize_row_q4_K_cuda;
        case GGML_TYPE_Q5_K:
            return dequantize_row_q5_K_cuda;
        case GGML_TYPE_Q6_K:
            return dequantize_row_q6_K_cuda;
        case GGML_TYPE_IQ2_XXS:
            return dequantize_row_iq2_xxs_cuda;
        case GGML_TYPE_IQ2_XS:
            return dequantize_row_iq2_xs_cuda;
        case GGML_TYPE_IQ2_S:
            return dequantize_row_iq2_s_cuda;
        case GGML_TYPE_IQ3_XXS:
            return dequantize_row_iq3_xxs_cuda;
        case GGML_TYPE_IQ1_S:
            return dequantize_row_iq1_s_cuda;
        case GGML_TYPE_IQ1_M:
            return dequantize_row_iq1_m_cuda;
        case GGML_TYPE_IQ4_NL:
            return dequantize_row_iq4_nl_cuda;
        case GGML_TYPE_IQ4_XS:
            return dequantize_row_iq4_xs_cuda;
        case GGML_TYPE_IQ3_S:
            return dequantize_row_iq3_s_cuda;
        case GGML_TYPE_F16:
            return convert_unary_cuda<half>;
        default:
            return nullptr;
    }
}
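
// Minimal host-side usage sketch (illustrative only; d_src, d_dst, k and
// stream are assumed to exist elsewhere and are not part of this file):
//
//   const to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_Q4_K);
//   if (to_fp32 != nullptr) {
//       // d_src: device buffer of k quantized values; d_dst: device buffer
//       // with room for k floats; k must respect the type's block size.
//       to_fp32(d_src, d_dst, k, stream);
//   }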