dequant_funcs_cm2.glsl 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734
  1. #include "types.glsl"
  2. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 {
  3. vec4 block;
  4. };
  5. float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  6. {
  7. const vec4 v = bl.block;
  8. const uint idx = coordInBlock[1];
  9. const f16vec4 vf16 = f16vec4(v);
  10. return vf16[idx];
  11. }
  12. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
  13. block_q4_0_packed16 block;
  14. };
  15. float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  16. {
  17. const float16_t d = bl.block.d;
  18. const uint idx = coordInBlock[1];
  19. const uint shift = (idx & 0x10) >> 2;
  20. uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]);
  21. qs >>= shift;
  22. qs &= 0x0F0F;
  23. qs = unpack8(qs)[idx & 1];
  24. float16_t ret = (float16_t(qs) - float16_t(8)) * d;
  25. return ret;
  26. }
  27. layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
  28. block_q4_1 block;
  29. };
  30. float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  31. {
  32. const float16_t d = bl.block.d;
  33. const float16_t m = bl.block.m;
  34. const uint idx = coordInBlock[1];
  35. const uint iqs = idx & 0xF;
  36. const uint shift = (idx & 0x10) >> 2;
  37. uint32_t qs = bl.block.qs[iqs];
  38. qs >>= shift;
  39. qs &= 0xF;
  40. float16_t ret = float16_t(qs) * d + m;
  41. return ret;
  42. }
  43. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
  44. block_q5_0 block;
  45. };
  46. float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  47. {
  48. const float16_t d = bl.block.d;
  49. const uint idx = coordInBlock[1];
  50. const uint iqs = idx & 0xF;
  51. const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
  52. const uint qh = ((uint_qh >> idx) << 4) & 0x10;
  53. const uint shift = (idx & 0x10) >> 2;
  54. uint32_t qs = bl.block.qs[iqs];
  55. qs >>= shift;
  56. qs &= 0xF;
  57. float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
  58. return ret;
  59. }
  60. layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
  61. block_q5_1 block;
  62. };
  63. float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  64. {
  65. const float16_t d = bl.block.d;
  66. const float16_t m = bl.block.m;
  67. const uint idx = coordInBlock[1];
  68. const uint iqs = idx & 0xF;
  69. const uint uint_qh = bl.block.qh;
  70. const uint qh = ((uint_qh >> idx) << 4) & 0x10;
  71. const uint shift = (idx & 0x10) >> 2;
  72. uint32_t qs = bl.block.qs[iqs];
  73. qs >>= shift;
  74. qs &= 0xF;
  75. float16_t ret = float16_t(qs | qh) * d + m;
  76. return ret;
  77. }
  78. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
  79. block_q8_0_packed16 block;
  80. };
  81. float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  82. {
  83. const float16_t d = bl.block.d;
  84. const uint idx = coordInBlock[1];
  85. const uint iqs = idx;
  86. // Load 16b and select the byte for this element
  87. int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1];
  88. float16_t ret = float16_t(qs) * d;
  89. return ret;
  90. }
  91. layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
  92. block_q2_K block;
  93. };
  94. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 {
  95. block_q2_K_packed16 block;
  96. };
  97. float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  98. {
  99. decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl);
  100. const f16vec2 dm = bl.block.dm;
  101. const uint idx = coordInBlock[1];
  102. const uint scalesi = (idx & 0xF0) >> 4; // 0..15
  103. const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6
  104. uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
  105. qs = (qs >> qsshift) & 0x0303;
  106. qs = unpack8(qs)[idx & 1];
  107. const uint scales = bl.block.scales[scalesi];
  108. float16_t ret = dm.x * float16_t(scales & 0xF) * float16_t(qs) - dm.y * float16_t(scales >> 4);
  109. return ret;
  110. }
  111. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
  112. block_q3_K block;
  113. };
  114. float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  115. {
  116. const uint idx = coordInBlock[1];
  117. const uint iqs = idx;
  118. const uint n = iqs / 128; // 0,1
  119. const uint qsi = n * 32 + (iqs % 32); // 0..63
  120. const uint hmi = (iqs % 32); // 0..31
  121. const uint j = (iqs % 128) / 8; // 0..15
  122. const uint is = iqs / 16; // 0..15
  123. const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
  124. const uint qsshift = halfsplit * 2; // 0,2,4,6
  125. const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
  126. uint32_t scaleidx0 = (is < 8) ? is : (is-8);
  127. uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
  128. uint32_t scaleidx1 = is + 8 - (is/4)*4;
  129. uint32_t scaleidx1shift = (is/4)*2;
  130. const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
  131. const float16_t dl = bl.block.d * float16_t(us - 32);
  132. float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
  133. return ret;
  134. }
  135. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
  136. block_q4_K block;
  137. };
  138. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 {
  139. block_q4_K_packed16 block;
  140. };
  141. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 {
  142. block_q4_K_packed128 block;
  143. };
  144. #if defined(IS_MUL_MM2)
  145. // For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
  146. // into shared memory and then process the whole tile using those scales.
  147. // There is a fetch function that loads into private variables and then a store
  148. // function that stores into shared memory.
  149. // Q4_K and Q5_K have the same encoding of scales, so everything is shared except
  150. // the part that fetches from the structure (which has a different block layout).
  151. #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
  152. const uint shAscales_stride = (BM + 2);
  153. // 1 scale per 32 elements -> 8 scales per block, per row
  154. shared vec2 shAscales[8 * shAscales_stride];
  155. uvec4 row_v;
  156. #endif
  157. #if defined(DATA_A_Q4_K)
  158. layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
  159. void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
  160. {
  161. uint tids_per_row = BLOCK_SIZE / BM;
  162. uint is_per_tid = 8 / tids_per_row;
  163. uint is_start = is_per_tid * (tid % tids_per_row);
  164. uint tid_row = tid / tids_per_row;
  165. uint row = ir_BM + tid_row;
  166. uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
  167. if (in_bounds || row < p.M) {
  168. row_v = data_a_q4_k_packed128[block_index].q4k[0];
  169. }
  170. }
  171. #endif
  172. #if defined(DATA_A_Q5_K)
  173. layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
  174. void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
  175. {
  176. uint tids_per_row = BLOCK_SIZE / BM;
  177. uint is_per_tid = 8 / tids_per_row;
  178. uint is_start = is_per_tid * (tid % tids_per_row);
  179. uint tid_row = tid / tids_per_row;
  180. uint row = ir_BM + tid_row;
  181. uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
  182. if (in_bounds || row < p.M) {
  183. row_v = data_a_q5_k_packed128[block_index].q5k[0];
  184. }
  185. }
  186. #endif
  187. #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
  188. void store_scalesQ4_K(uint tid)
  189. {
  190. barrier();
  191. uint tids_per_row = BLOCK_SIZE / BM;
  192. uint is_per_tid = 8 / tids_per_row;
  193. uint is_start = is_per_tid * (tid % tids_per_row);
  194. uint tid_row = tid / tids_per_row;
  195. [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
  196. uint is = idx + is_start;
  197. uvec4 v = row_v;
  198. const vec2 loadd = vec2(unpackFloat2x16(v.x));
  199. uint32_t sc;
  200. uint32_t mbyte;
  201. uint32_t scale0 = v.y;
  202. uint32_t scale4 = v.z;
  203. uint32_t scale8 = v.w;
  204. uint32_t sc_lo = scale0;
  205. uint32_t mb_lo = scale4;
  206. uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
  207. uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
  208. sc = is < 4 ? sc_lo : sc_hi;
  209. mbyte = is < 4 ? mb_lo : mb_hi;
  210. sc = sc >> (8 * (is & 3));
  211. mbyte = mbyte >> (8 * (is & 3));
  212. sc &= 0x3F;
  213. mbyte &= 0x3F;
  214. const float d = loadd.x * float(sc);
  215. const float m = loadd.y * float(mbyte);
  216. shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
  217. }
  218. barrier();
  219. }
  220. #endif
  221. #endif
  222. float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  223. {
  224. decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
  225. decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl);
  226. const uint idx = coordInBlock[1];
  227. const uint b = (idx & 0x20) >> 5; // 0,1
  228. const uint is = (idx & 0xE0) >> 5; // 0..7
  229. #if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
  230. vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
  231. float d = v.x;
  232. float m = v.y;
  233. #else
  234. uvec4 v = bl128.block.q4k[0];
  235. const vec2 loadd = vec2(unpackFloat2x16(v.x));
  236. uint32_t sc;
  237. uint32_t mbyte;
  238. uint32_t scale0 = v.y;
  239. uint32_t scale4 = v.z;
  240. uint32_t scale8 = v.w;
  241. uint32_t sc_lo = scale0;
  242. uint32_t mb_lo = scale4;
  243. uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
  244. uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
  245. sc = is < 4 ? sc_lo : sc_hi;
  246. mbyte = is < 4 ? mb_lo : mb_hi;
  247. sc = sc >> (8 * (is & 3));
  248. mbyte = mbyte >> (8 * (is & 3));
  249. sc &= 0x3F;
  250. mbyte &= 0x3F;
  251. const float d = loadd.x * float(sc);
  252. const float m = loadd.y * float(mbyte);
  253. #endif
  254. uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
  255. qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
  256. float ret = d * float(qs) - m;
  257. return float16_t(ret);
  258. }
  259. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
  260. block_q5_K block;
  261. };
  262. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 {
  263. block_q5_K_packed16 block;
  264. };
  265. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 {
  266. block_q5_K_packed128 block;
  267. };
  268. float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  269. {
  270. decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl);
  271. decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl);
  272. const uint idx = coordInBlock[1];
  273. const uint b = (idx & 0x20) >> 5; // 0,1
  274. const uint is = (idx & 0xE0) >> 5; // 0..7
  275. #if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
  276. vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
  277. float d = v.x;
  278. float m = v.y;
  279. #else
  280. uvec4 v = bl128.block.q5k[0];
  281. const f16vec2 loadd = unpackFloat2x16(v.x);
  282. uint32_t sc;
  283. uint32_t mbyte;
  284. uint32_t scale0 = v.y;
  285. uint32_t scale4 = v.z;
  286. uint32_t scale8 = v.w;
  287. uint32_t sc_lo = scale0;
  288. uint32_t mb_lo = scale4;
  289. uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
  290. uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
  291. sc = is < 4 ? sc_lo : sc_hi;
  292. mbyte = is < 4 ? mb_lo : mb_hi;
  293. sc = sc >> (8 * (is & 3));
  294. mbyte = mbyte >> (8 * (is & 3));
  295. sc &= 0x3F;
  296. mbyte &= 0x3F;
  297. const float16_t d = loadd.x * float16_t(sc);
  298. const float16_t m = loadd.y * float16_t(mbyte);
  299. #endif
  300. uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
  301. qh = ((qh >> is) & 0x101) << 4;
  302. uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
  303. qs = (qs >> (b * 4)) & 0x0F0F;
  304. qs = unpack8(qs | qh)[idx & 1];
  305. float ret = d * float(qs) - m;
  306. return float16_t(ret);
  307. }
  308. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
  309. block_q6_K block;
  310. };
  311. layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 {
  312. block_q6_K_packed16 block;
  313. };
  314. float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  315. {
  316. decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl);
  317. const uint idx = coordInBlock[1];
  318. const uint b = (idx & 0x40) >> 6; // 0,1
  319. const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6
  320. const uint is = (idx & 0xF0) >> 4; // 0..15
  321. const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
  322. uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]);
  323. ql = (ql >> (b * 4)) & 0x0F0F;
  324. uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]);
  325. qh = ((qh >> qhshift) & 0x0303) << 4;
  326. int q = unpack8(ql | qh)[idx & 1];
  327. float16_t ret = dscale * float16_t(q - 32);
  328. return ret;
  329. }
  330. #if defined(DATA_A_IQ1_S)
  331. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S {
  332. block_iq1_s block;
  333. };
  334. float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  335. {
  336. const float16_t d = bl.block.d;
  337. const uint idx = coordInBlock[1];
  338. const uint ib32 = (idx & 0xE0) >> 5;
  339. const uint ib8 = (idx & 0xF8) >> 3;
  340. const uint qh = bl.block.qh[ib32];
  341. const uint qs = bl.block.qs[ib8];
  342. const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
  343. const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
  344. const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)];
  345. float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta));
  346. return ret;
  347. }
  348. #endif
  349. #if defined(DATA_A_IQ1_M)
  350. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M {
  351. block_iq1_m block;
  352. };
  353. layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 {
  354. block_iq1_m_packed64 block;
  355. };
  356. float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  357. {
  358. decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl);
  359. const uint idx = coordInBlock[1];
  360. uvec2 scales = unpack32(bl64.block.scales);
  361. const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16)));
  362. const uint ib8 = (idx & 0xF8) >> 3;
  363. const uint ib16 = (idx & 0xF0) >> 4;
  364. const int i8 = int(idx % 8);
  365. const uint sc = bl.block.scales[ib8 / 8];
  366. const uint qs = bl.block.qs[ib8];
  367. const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1));
  368. const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1;
  369. const float delta = ((qh & 8) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
  370. const uint grid = iq1s_grid[qs | ((qh & 7) << 8)];
  371. float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta));
  372. return ret;
  373. }
  374. #endif
  375. #if defined(DATA_A_IQ2_XXS)
  376. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS {
  377. block_iq2_xxs block;
  378. };
  379. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 {
  380. block_iq2_xxs_packed16 block;
  381. };
  382. float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  383. {
  384. decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl);
  385. const float16_t d = bl.block.d;
  386. const uint idx = coordInBlock[1];
  387. const uint ib32 = (idx & 0xE0) >> 5; // 0..7
  388. const uint ib8 = (idx & 0x18) >> 3; // 0..3
  389. const uint iqs = 8 * ib32 + ib8;
  390. const uint qs = bl.block.qs[iqs];
  391. const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
  392. const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
  393. uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
  394. sign |= bitCount(sign) << 7;
  395. uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
  396. g2 >>= (idx & 2) * 8;
  397. const vec2 g = vec2(unpack8(g2));
  398. vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
  399. return float16_t(ret[idx & 1]);
  400. }
  401. #endif
  402. #if defined(DATA_A_IQ2_XS)
  403. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS {
  404. block_iq2_xs block;
  405. };
  406. float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  407. {
  408. const float16_t d = bl.block.d;
  409. const uint idx = coordInBlock[1];
  410. const uint is = (idx & 0xE0) >> 5; // 0..8
  411. const uint sshift = (idx & 0x10) >> 2; // 0,4
  412. const uint iqs = (idx & 0xF8) >> 3; // 0..63
  413. const uint16_t qs = bl.block.qs[iqs];
  414. const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
  415. uint sign = uint(qs >> 9);
  416. sign |= bitCount(sign) << 7;
  417. uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
  418. g2 >>= (idx & 2) * 8;
  419. const vec2 g = vec2(unpack8(g2));
  420. vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
  421. return float16_t(ret[idx & 1]);
  422. }
  423. #endif
  424. #if defined(DATA_A_IQ2_S)
  425. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S {
  426. block_iq2_s block;
  427. };
  428. float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  429. {
  430. uint idx = coordInBlock[1];
  431. const uint ib32 = (idx & 0xE0) >> 5; // 0..7
  432. const uint ib8 = (idx & 0xF8) >> 3; // 0..31
  433. const uint qhshift = 2 * (ib8 % 4);
  434. const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
  435. const uint qs = bl.block.qs[ib8];
  436. const uint qh = bl.block.qh[ib32];
  437. const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
  438. const float d = float(bl.block.d);
  439. const float db = d * 0.25 * (0.5 + scale);
  440. const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
  441. uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
  442. g2 >>= (idx & 2) * 8;
  443. const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
  444. return float16_t(v[idx & 1]);
  445. }
  446. #endif
  447. #if defined(DATA_A_IQ3_XXS)
  448. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS {
  449. block_iq3_xxs block;
  450. };
  451. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 {
  452. block_iq3_xxs_packed16 block;
  453. };
  454. float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  455. {
  456. decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
  457. uint idx = coordInBlock[1];
  458. const uint iqs = (idx & 0xFC) >> 2; // 0..63
  459. const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
  460. const float d = float(bl.block.d);
  461. const uint qs = bl.block.qs[iqs];
  462. const uint signs = pack32(u16vec2(
  463. bl16.block.qs[is/2+0],
  464. bl16.block.qs[is/2+1]
  465. ));
  466. const float db = d * 0.5 * (0.5 + (signs >> 28));
  467. const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
  468. const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
  469. const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
  470. const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
  471. const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
  472. return float16_t(v[idx & 1]);
  473. }
  474. #endif
  475. #if defined(DATA_A_IQ3_S)
  476. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S {
  477. block_iq3_s block;
  478. };
  479. float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  480. {
  481. uint idx = coordInBlock[1];
  482. const uint iqs = (idx & 0xFC) >> 2; // 0..63
  483. const uint iqh = (idx & 0xE0) >> 5;
  484. const float d = float(bl.block.d);
  485. const uint qs = bl.block.qs[iqs];
  486. const uint qh = bl.block.qh[iqh];
  487. const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
  488. const uint scale = bl.block.scales[iqs / 16];
  489. const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
  490. const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
  491. const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
  492. const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
  493. return float16_t(v[idx & 1]);
  494. }
  495. #endif
  496. #if defined(DATA_A_IQ4_XS)
  497. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
  498. block_iq4_xs block;
  499. };
  500. float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  501. {
  502. const float16_t d = bl.block.d;
  503. const uint idx = coordInBlock[1];
  504. const uint ib32 = (idx & 0xE0) >> 5; // 0..7
  505. const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
  506. const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
  507. const uint qshift = (idx & 16) >> 2;
  508. const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
  509. float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
  510. return ret;
  511. }
  512. #endif
  513. #if defined(DATA_A_IQ4_NL)
  514. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
  515. block_iq4_nl block;
  516. };
  517. float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  518. {
  519. const float16_t d = bl.block.d;
  520. const uint idx = coordInBlock[1];
  521. const uint iqs = idx & 0xF;
  522. const uint shift = (idx & 0x10) >> 2;
  523. uint32_t qs = bl.block.qs[iqs];
  524. qs >>= shift;
  525. qs &= 0xF;
  526. float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
  527. return ret;
  528. }
  529. #endif
  530. #if defined(DATA_A_MXFP4)
  531. layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
  532. block_mxfp4 block;
  533. };
  534. float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
  535. {
  536. const float d = e8m0_to_fp32(bl.block.e);
  537. const uint idx = coordInBlock[1];
  538. const uint iqs = idx & 0xF;
  539. const uint shift = (idx & 0x10) >> 2;
  540. uint32_t qs = bl.block.qs[iqs];
  541. qs >>= shift;
  542. qs &= 0xF;
  543. float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5);
  544. return ret;
  545. }
  546. #endif
  547. #if defined(DATA_A_Q4_0)
  548. #define dequantFuncA dequantFuncQ4_0
  549. #elif defined(DATA_A_Q4_1)
  550. #define dequantFuncA dequantFuncQ4_1
  551. #elif defined(DATA_A_Q5_0)
  552. #define dequantFuncA dequantFuncQ5_0
  553. #elif defined(DATA_A_Q5_1)
  554. #define dequantFuncA dequantFuncQ5_1
  555. #elif defined(DATA_A_Q8_0)
  556. #define dequantFuncA dequantFuncQ8_0
  557. #elif defined(DATA_A_Q2_K)
  558. #define dequantFuncA dequantFuncQ2_K
  559. #elif defined(DATA_A_Q3_K)
  560. #define dequantFuncA dequantFuncQ3_K
  561. #elif defined(DATA_A_Q4_K)
  562. #define dequantFuncA dequantFuncQ4_K
  563. #define fetch_scales fetch_scalesQ4_K
  564. #define store_scales store_scalesQ4_K
  565. #elif defined(DATA_A_Q5_K)
  566. #define dequantFuncA dequantFuncQ5_K
  567. #define fetch_scales fetch_scalesQ5_K
  568. #define store_scales store_scalesQ4_K
  569. #elif defined(DATA_A_Q6_K)
  570. #define dequantFuncA dequantFuncQ6_K
  571. #elif defined(DATA_A_IQ1_S)
  572. #define dequantFuncA dequantFuncIQ1_S
  573. #elif defined(DATA_A_IQ1_M)
  574. #define dequantFuncA dequantFuncIQ1_M
  575. #elif defined(DATA_A_IQ2_XXS)
  576. #define dequantFuncA dequantFuncIQ2_XXS
  577. #elif defined(DATA_A_IQ2_XS)
  578. #define dequantFuncA dequantFuncIQ2_XS
  579. #elif defined(DATA_A_IQ2_S)
  580. #define dequantFuncA dequantFuncIQ2_S
  581. #elif defined(DATA_A_IQ3_XXS)
  582. #define dequantFuncA dequantFuncIQ3_XXS
  583. #elif defined(DATA_A_IQ3_S)
  584. #define dequantFuncA dequantFuncIQ3_S
  585. #elif defined(DATA_A_IQ4_XS)
  586. #define dequantFuncA dequantFuncIQ4_XS
  587. #elif defined(DATA_A_IQ4_NL)
  588. #define dequantFuncA dequantFuncIQ4_NL
  589. #elif defined(DATA_A_MXFP4)
  590. #define dequantFuncA dequantFuncMXFP4
  591. #elif defined(DATA_A_F32)
  592. #define dequantFuncA dequantFuncF32
  593. #endif