dequantize.hpp 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690
  1. //
  2. // MIT license
  3. // Copyright (C) 2024 Intel Corporation
  4. // SPDX-License-Identifier: MIT
  5. //
  6. //
  7. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  8. // See https://llvm.org/LICENSE.txt for license information.
  9. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  10. //
  11. #ifndef GGML_SYCL_DEQUANTIZE_HPP
  12. #define GGML_SYCL_DEQUANTIZE_HPP
  13. #include "common.hpp"
  14. typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
  15. static __dpct_inline__ void dequantize_q4_0(const void *vx, const int ib,
  16. const int iqs, dfloat2 &v) {
  17. const block_q4_0 * x = (const block_q4_0 *) vx;
  18. const dfloat d = x[ib].d;
  19. const int vui = x[ib].qs[iqs];
  20. v.x() = vui & 0xF;
  21. v.y() = vui >> 4;
  22. #ifdef GGML_SYCL_F16
  23. // v = v - {8.0f, 8.0f};
  24. // v = v * {d, d};
  25. v.s0() = (v.s0() - 8.0f) * d;
  26. v.s1() = (v.s1() - 8.0f) * d;
  27. #else
  28. v.x() = (v.x() - 8.0f) * d;
  29. v.y() = (v.y() - 8.0f) * d;
  30. #endif // GGML_SYCL_F16
  31. }
  32. static __dpct_inline__ void dequantize_q4_1(const void *vx, const int ib,
  33. const int iqs, dfloat2 &v) {
  34. const block_q4_1 * x = (const block_q4_1 *) vx;
  35. const dfloat d = x[ib].dm[0];
  36. const dfloat m = x[ib].dm[1];
  37. const int vui = x[ib].qs[iqs];
  38. v.x() = vui & 0xF;
  39. v.y() = vui >> 4;
  40. #ifdef GGML_SYCL_F16
  41. // v = v * {d, d};
  42. // v = v + {m, m};
  43. v.s0() = (v.s0() * d) + m;
  44. v.s1() = (v.s1() * d) + m;
  45. #else
  46. v.x() = (v.x() * d) + m;
  47. v.y() = (v.y() * d) + m;
  48. #endif // GGML_SYCL_F16
  49. }
  50. static __dpct_inline__ void dequantize_q5_0(const void *vx, const int ib,
  51. const int iqs, dfloat2 &v) {
  52. const block_q5_0 * x = (const block_q5_0 *) vx;
  53. const dfloat d = x[ib].d;
  54. uint32_t qh;
  55. memcpy(&qh, x[ib].qh, sizeof(qh));
  56. const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  57. const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  58. v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
  59. v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
  60. #ifdef GGML_SYCL_F16
  61. // v = v - {16.0f, 16.0f};
  62. // v = v * {d, d};
  63. v.s0() = (v.s0() - 16.0f) * d;
  64. v.s1() = (v.s1() - 16.0f) * d;
  65. #else
  66. v.x() = (v.x() - 16.0f) * d;
  67. v.y() = (v.y() - 16.0f) * d;
  68. #endif // GGML_SYCL_F16
  69. }
  70. static __dpct_inline__ void dequantize_q5_1(const void *vx, const int ib,
  71. const int iqs, dfloat2 &v) {
  72. const block_q5_1 * x = (const block_q5_1 *) vx;
  73. const dfloat d = x[ib].dm[0];
  74. const dfloat m = x[ib].dm[1];
  75. uint32_t qh;
  76. memcpy(&qh, x[ib].qh, sizeof(qh));
  77. const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
  78. const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
  79. v.x() = ((x[ib].qs[iqs] & 0xf) | xh_0);
  80. v.y() = ((x[ib].qs[iqs] >> 4) | xh_1);
  81. #ifdef GGML_SYCL_F16
  82. // v = v * {d, d};
  83. // v = v + {m, m};
  84. v.s0() = (v.s0() * d) + m;
  85. v.s1() = (v.s1() * d) + m;
  86. #else
  87. v.x() = (v.x() * d) + m;
  88. v.y() = (v.y() * d) + m;
  89. #endif // GGML_SYCL_F16
  90. }
  91. static __dpct_inline__ void dequantize_q8_0(const void *vx, const int ib,
  92. const int iqs, dfloat2 &v) {
  93. const block_q8_0 * x = (const block_q8_0 *) vx;
  94. const dfloat d = x[ib].d;
  95. v.x() = x[ib].qs[iqs + 0];
  96. v.y() = x[ib].qs[iqs + 1];
  97. #ifdef GGML_SYCL_F16
  98. // v = v * {d, d};
  99. v.s0() *= d;
  100. v.s1() *= d;
  101. #else
  102. v.x() *= d;
  103. v.y() *= d;
  104. #endif // GGML_SYCL_F16
  105. }
  106. template<typename dst_t>
  107. static void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
  108. const sycl::nd_item<3> &item_ct1) {
  109. const int i = item_ct1.get_group(2);
  110. // assume 32 threads
  111. const int tid = item_ct1.get_local_id(2);
  112. const int il = tid/8;
  113. const int ir = tid%8;
  114. const int ib = 8*i + ir;
  115. if (ib >= nb32) {
  116. return;
  117. }
  118. dst_t * y = yy + 256*i + 32*ir + 4*il;
  119. const block_q4_0 * x = (const block_q4_0 *)vx + ib;
  120. const float d = sycl::vec<sycl::half, 1>(x->d)
  121. .convert<float, sycl::rounding_mode::automatic>()[0];
  122. const float dm = -8*d;
  123. const uint8_t * q = x->qs + 4*il;
  124. for (int l = 0; l < 4; ++l) {
  125. y[l+ 0] = d * (q[l] & 0xF) + dm;
  126. y[l+16] = d * (q[l] >> 4) + dm;
  127. }
  128. }
  129. template<typename dst_t>
  130. static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32,
  131. const sycl::nd_item<3> &item_ct1) {
  132. const int i = item_ct1.get_group(2);
  133. // assume 32 threads
  134. const int tid = item_ct1.get_local_id(2);
  135. const int il = tid/8;
  136. const int ir = tid%8;
  137. const int ib = 8*i + ir;
  138. if (ib >= nb32) {
  139. return;
  140. }
  141. dst_t * y = yy + 256*i + 32*ir + 4*il;
  142. const block_q4_1 * x = (const block_q4_1 *)vx + ib;
  143. const sycl::float2 d =
  144. x->dm.convert<float, sycl::rounding_mode::automatic>();
  145. const uint8_t * q = x->qs + 4*il;
  146. for (int l = 0; l < 4; ++l) {
  147. y[l + 0] = d.x() * (q[l] & 0xF) + d.y();
  148. y[l + 16] = d.x() * (q[l] >> 4) + d.y();
  149. }
  150. }
  151. //================================== k-quants
  152. template<typename dst_t>
  153. static void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
  154. const sycl::nd_item<3> &item_ct1) {
  155. const int i = item_ct1.get_group(2);
  156. const block_q2_K * x = (const block_q2_K *) vx;
  157. const int tid = item_ct1.get_local_id(2);
  158. #if QK_K == 256
  159. const int n = tid/32;
  160. const int l = tid - 32*n;
  161. const int is = 8*n + l/16;
  162. const uint8_t q = x[i].qs[32*n + l];
  163. dst_t * y = yy + i*QK_K + 128*n;
  164. float dall = x[i].dm[0];
  165. float dmin = x[i].dm[1];
  166. y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  167. y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
  168. y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
  169. y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
  170. #else
  171. const int is = tid/16; // 0 or 1
  172. const int il = tid%16; // 0...15
  173. const uint8_t q = x[i].qs[il] >> (2*is);
  174. dst_t * y = yy + i*QK_K + 16*is + il;
  175. float dall = x[i].dm[0];
  176. float dmin = x[i].dm[1];
  177. y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
  178. y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
  179. #endif
  180. }
  181. template<typename dst_t>
  182. static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
  183. const sycl::nd_item<3> &item_ct1) {
  184. const int i = item_ct1.get_group(2);
  185. const block_q3_K * x = (const block_q3_K *) vx;
  186. #if QK_K == 256
  187. const int r = item_ct1.get_local_id(2) / 4;
  188. const int tid = r/2;
  189. const int is0 = r%2;
  190. const int l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4);
  191. const int n = tid / 4;
  192. const int j = tid - 4*n;
  193. uint8_t m = 1 << (4*n + j);
  194. int is = 8*n + 2*j + is0;
  195. int shift = 2*j;
  196. int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
  197. is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
  198. is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
  199. (x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
  200. float d_all = x[i].d;
  201. float dl = d_all * (us - 32);
  202. dst_t * y = yy + i*QK_K + 128*n + 32*j;
  203. const uint8_t * q = x[i].qs + 32*n;
  204. const uint8_t * hm = x[i].hmask;
  205. for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
  206. #else
  207. const int tid = item_ct1.get_local_id(2);
  208. const int is = tid/16; // 0 or 1
  209. const int il = tid%16; // 0...15
  210. const int im = il/8; // 0...1
  211. const int in = il%8; // 0...7
  212. dst_t * y = yy + i*QK_K + 16*is + il;
  213. const uint8_t q = x[i].qs[il] >> (2*is);
  214. const uint8_t h = x[i].hmask[in] >> (2*is + im);
  215. const float d = (float)x[i].d;
  216. if (is == 0) {
  217. y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
  218. y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
  219. } else {
  220. y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
  221. y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
  222. }
  223. #endif
  224. }
  #if QK_K == 256
  // Unpack the 6-bit scale `d` and 6-bit min `m` of sub-block j (0...7) from
  // the packed scales array `q`. Sub-blocks 0..3 keep all 6 bits in q[j] and
  // q[j + 4]. Sub-blocks 4..7 take 4 low bits from q[j + 4] and 2 high bits from
  // the top of q[j - 4] (scale) or q[j] (min).
  static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
      if (j >= 4) {
          d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
          m = (q[j + 4] >> 4)  | ((q[j]     >> 6) << 4);
          return;
      }
      d = q[j] & 63;
      m = q[j + 4] & 63;
  }
  #endif
  235. template<typename dst_t>
  236. static void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
  237. const sycl::nd_item<3> &item_ct1) {
  238. const block_q4_K * x = (const block_q4_K *) vx;
  239. const int i = item_ct1.get_group(2);
  240. #if QK_K == 256
  241. // assume 32 threads
  242. const int tid = item_ct1.get_local_id(2);
  243. const int il = tid/8;
  244. const int ir = tid%8;
  245. const int is = 2*il;
  246. const int n = 4;
  247. dst_t * y = yy + i*QK_K + 64*il + n*ir;
  248. const float dall = x[i].dm[0];
  249. const float dmin = x[i].dm[1];
  250. const uint8_t * q = x[i].qs + 32*il + n*ir;
  251. uint8_t sc, m;
  252. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  253. const float d1 = dall * sc; const float m1 = dmin * m;
  254. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  255. const float d2 = dall * sc; const float m2 = dmin * m;
  256. for (int l = 0; l < n; ++l) {
  257. y[l + 0] = d1 * (q[l] & 0xF) - m1;
  258. y[l +32] = d2 * (q[l] >> 4) - m2;
  259. }
  260. #else
  261. const int tid = item_ct1.get_local_id(2);
  262. const uint8_t * q = x[i].qs;
  263. dst_t * y = yy + i*QK_K;
  264. const float d = (float)x[i].dm[0];
  265. const float m = (float)x[i].dm[1];
  266. y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
  267. y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
  268. #endif
  269. }
  270. template<typename dst_t>
  271. static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
  272. const sycl::nd_item<3> &item_ct1) {
  273. const block_q5_K * x = (const block_q5_K *) vx;
  274. const int i = item_ct1.get_group(2);
  275. #if QK_K == 256
  276. // assume 64 threads - this is very slightly better than the one below
  277. const int tid = item_ct1.get_local_id(2);
  278. const int il = tid/16; // il is in 0...3
  279. const int ir = tid%16; // ir is in 0...15
  280. const int is = 2*il; // is is in 0...6
  281. dst_t * y = yy + i*QK_K + 64*il + 2*ir;
  282. const float dall = x[i].dm[0];
  283. const float dmin = x[i].dm[1];
  284. const uint8_t * ql = x[i].qs + 32*il + 2*ir;
  285. const uint8_t * qh = x[i].qh + 2*ir;
  286. uint8_t sc, m;
  287. get_scale_min_k4(is + 0, x[i].scales, sc, m);
  288. const float d1 = dall * sc; const float m1 = dmin * m;
  289. get_scale_min_k4(is + 1, x[i].scales, sc, m);
  290. const float d2 = dall * sc; const float m2 = dmin * m;
  291. uint8_t hm = 1 << (2*il);
  292. y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
  293. y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
  294. hm <<= 1;
  295. y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
  296. y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
  297. #else
  298. const int tid = item_ct1.get_local_id(2);
  299. const uint8_t q = x[i].qs[tid];
  300. const int im = tid/8; // 0...3
  301. const int in = tid%8; // 0...7
  302. const int is = tid/16; // 0 or 1
  303. const uint8_t h = x[i].qh[in] >> im;
  304. const float d = x[i].d;
  305. dst_t * y = yy + i*QK_K + tid;
  306. y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
  307. y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
  308. #endif
  309. }
  310. template<typename dst_t>
  311. static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy,
  312. const sycl::nd_item<3> &item_ct1) {
  313. const block_q6_K * x = (const block_q6_K *) vx;
  314. const int i = item_ct1.get_group(2);
  315. #if QK_K == 256
  316. // assume 64 threads - this is very slightly better than the one below
  317. const int tid = item_ct1.get_local_id(2);
  318. const int ip = tid/32; // ip is 0 or 1
  319. const int il = tid - 32*ip; // 0...32
  320. const int is = 8*ip + il/16;
  321. dst_t * y = yy + i*QK_K + 128*ip + il;
  322. const float d = x[i].d;
  323. const uint8_t * ql = x[i].ql + 64*ip + il;
  324. const uint8_t qh = x[i].qh[32*ip + il];
  325. const int8_t * sc = x[i].scales + is;
  326. y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  327. y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
  328. y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  329. y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
  330. #else
  331. // assume 32 threads
  332. const int tid = item_ct1.get_local_id(2);
  333. const int ip = tid/16; // 0 or 1
  334. const int il = tid - 16*ip; // 0...15
  335. dst_t * y = yy + i*QK_K + 16*ip + il;
  336. const float d = x[i].d;
  337. const uint8_t ql = x[i].ql[16*ip + il];
  338. const uint8_t qh = x[i].qh[il] >> (2*ip);
  339. const int8_t * sc = x[i].scales;
  340. y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
  341. y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
  342. #endif
  343. }
  344. template<typename dst_t>
  345. static void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
  346. const sycl::nd_item<3> &item_ct1,
  347. const uint64_t *iq2xxs_grid_ptr,
  348. const uint8_t *ksigns_iq2xs_ptr,
  349. const uint8_t *kmask_iq2xs_ptr) {
  350. const int i = item_ct1.get_group(2);
  351. const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
  352. const int tid = item_ct1.get_local_id(2);
  353. #if QK_K == 256
  354. const int il = tid/8; // 0...3
  355. const int ib = tid%8; // 0...7
  356. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  357. const uint16_t * q2 = x[i].qs + 4*ib;
  358. const uint8_t * aux8 = (const uint8_t *)q2;
  359. const uint8_t * grid = (const uint8_t *)(iq2xxs_grid_ptr + aux8[il]);
  360. const uint32_t aux32 = q2[2] | (q2[3] << 16);
  361. const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
  362. const uint8_t signs = ksigns_iq2xs_ptr[(aux32 >> 7*il) & 127];
  363. for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs_ptr[j] ? -1.f : 1.f);
  364. #else
  365. assert(false);
  366. #endif
  367. }
  368. template<typename dst_t>
  369. static void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy,
  370. const sycl::nd_item<3> &item_ct1,
  371. const uint64_t *iq2xs_grid,
  372. const uint8_t *ksigns_iq2xs,
  373. const uint8_t *kmask_iq2xs) {
  374. const int i = item_ct1.get_group(2);
  375. const block_iq2_xs * x = (const block_iq2_xs *) vx;
  376. const int tid = item_ct1.get_local_id(2);
  377. #if QK_K == 256
  378. const int il = tid/8; // 0...3
  379. const int ib = tid%8; // 0...7
  380. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  381. const uint16_t * q2 = x[i].qs + 4*ib;
  382. const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
  383. const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  384. const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
  385. for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  386. #else
  387. assert(false);
  388. #endif
  389. }
  390. template <typename dst_t>
  391. __dpct_inline__ static void
  392. dequantize_block_iq2_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
  393. const sycl::nd_item<3> &item_ct1) {
  394. const int i = item_ct1.get_group(2);
  395. const block_iq2_s * x = (const block_iq2_s *) vx;
  396. const int tid = item_ct1.get_local_id(2);
  397. #if QK_K == 256
  398. const int il = tid/8; // 0...3
  399. const int ib = tid%8; // 0...7
  400. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  401. const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
  402. const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
  403. const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
  404. #pragma unroll
  405. for (int j = 0; j < 8; ++j)
  406. y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
  407. #else
  408. assert(false);
  409. #endif
  410. }
  411. template<typename dst_t>
  412. static void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy,
  413. const sycl::nd_item<3> &item_ct1,
  414. const uint32_t *iq3xxs_grid,
  415. const uint8_t *ksigns_iq2xs,
  416. const uint8_t *kmask_iq2xs) {
  417. const int i = item_ct1.get_group(2);
  418. const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
  419. const int tid = item_ct1.get_local_id(2);
  420. #if QK_K == 256
  421. const int il = tid/8; // 0...3
  422. const int ib = tid%8; // 0...7
  423. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  424. const uint8_t * q3 = x[i].qs + 8*ib;
  425. const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
  426. const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
  427. const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
  428. const uint32_t aux32 = gas[0] | (gas[1] << 16);
  429. const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
  430. const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
  431. for (int j = 0; j < 4; ++j) {
  432. y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
  433. y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
  434. }
  435. #else
  436. assert(false);
  437. #endif
  438. }
  439. template <typename dst_t>
  440. __dpct_inline__ static void
  441. dequantize_block_iq3_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
  442. const sycl::nd_item<3> &item_ct1,
  443. const uint8_t *kmask_iq2xs, const uint32_t *iq3s_grid) {
  444. const int i = item_ct1.get_group(2);
  445. const block_iq3_s * x = (const block_iq3_s *) vx;
  446. const int tid = item_ct1.get_local_id(2);
  447. #if QK_K == 256
  448. const int il = tid/8; // 0...3
  449. const int ib = tid%8; // 0...7
  450. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  451. const uint8_t * qs = x[i].qs + 8*ib;
  452. const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
  453. const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
  454. const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
  455. const uint8_t signs = x[i].signs[4*ib + il];
  456. #pragma unroll
  457. for (int j = 0; j < 4; ++j) {
  458. y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
  459. y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
  460. }
  461. #else
  462. assert(false);
  463. #endif
  464. }
  465. template <typename dst_t>
  466. __dpct_inline__ static void
  467. dequantize_block_iq1_s(const void *__restrict__ vx, dst_t *__restrict__ yy,
  468. const sycl::nd_item<3> &item_ct1,
  469. const uint32_t *iq1s_grid_gpu) {
  470. const int i = item_ct1.get_group(2);
  471. const block_iq1_s * x = (const block_iq1_s *) vx;
  472. const int tid = item_ct1.get_local_id(2);
  473. #if QK_K == 256
  474. const int il = tid/8; // 0...3
  475. const int ib = tid%8; // 0...7
  476. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  477. const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
  478. const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
  479. uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
  480. grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
  481. grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  482. grid32[0] &= 0x0f0f0f0f;
  483. #pragma unroll
  484. for (int j = 0; j < 8; ++j) {
  485. y[j] = d * (q[j] + delta);
  486. }
  487. #else
  488. assert(false);
  489. #endif
  490. }
  491. template <typename dst_t>
  492. __dpct_inline__ static void
  493. dequantize_block_iq1_m(const void *__restrict__ vx, dst_t *__restrict__ yy,
  494. const sycl::nd_item<3> &item_ct1,
  495. const uint32_t *iq1s_grid_gpu) {
  496. const int i = item_ct1.get_group(2);
  497. const block_iq1_m * x = (const block_iq1_m *) vx;
  498. const int tid = item_ct1.get_local_id(2);
  499. #if QK_K == 256
  500. const int il = tid/8; // 0...3
  501. const int ib = tid%8; // 0...7
  502. dst_t * y = yy + i*QK_K + 32*ib + 8*il;
  503. const uint16_t * sc = (const uint16_t *)x[i].scales;
  504. iq1m_scale_t scale;
  505. scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
  506. const int ib16 = 2*ib + il/2; // sc[ib16/4] >> 3*(ib16%4) -> sc[ib/2] >> 3*((2*ib+il/2)%4);
  507. const float d = (float)scale.f16 * (2*((sc[ib16/4] >> 3*(ib16%4)) & 0x7) + 1);
  508. const float delta = x[i].qh[2*ib+il/2] & (0x08 << 4*(il%2)) ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA;
  509. uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
  510. grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[2*ib+il/2] >> 4*(il%2)) & 7) << 8)];
  511. grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
  512. grid32[0] &= 0x0f0f0f0f;
  513. #pragma unroll
  514. for (int j = 0; j < 8; ++j) {
  515. y[j] = d * (q[j] + delta);
  516. }
  517. #else
  518. assert(false);
  519. #endif
  520. }
  521. template <typename dst_t>
  522. __dpct_inline__ static void
  523. dequantize_block_iq4_nl(const void *__restrict__ vx, dst_t *__restrict__ yy,
  524. const sycl::nd_item<3> &item_ct1) {
  525. const int i = item_ct1.get_group(2);
  526. const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
  527. const int tid = item_ct1.get_local_id(2);
  528. const int il = tid/8; // 0...3
  529. const int ib = tid%8; // 0...7
  530. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  531. const uint8_t * q4 = x[ib].qs + 4*il;
  532. const float d = (float)x[ib].d;
  533. #pragma unroll
  534. for (int j = 0; j < 4; ++j) {
  535. y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
  536. y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
  537. }
  538. }
  539. template <typename dst_t>
  540. __dpct_inline__ static void
  541. dequantize_block_iq4_xs(const void *__restrict__ vx, dst_t *__restrict__ yy,
  542. const sycl::nd_item<3> &item_ct1) {
  543. const int i = item_ct1.get_group(2);
  544. const block_iq4_xs * x = (const block_iq4_xs *)vx;
  545. const int tid = item_ct1.get_local_id(2);
  546. const int il = tid/8; // 0...3
  547. const int ib = tid%8; // 0...7
  548. dst_t * y = yy + i*QK_K + 32*ib + 4*il;
  549. const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
  550. const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
  551. #pragma unroll
  552. for (int j = 0; j < 4; ++j) {
  553. y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
  554. y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
  555. }
  556. }
  557. #endif // GGML_SYCL_DEQUANTIZE_HPP