k_quants.c

  1. #include "k_quants.h"
  2. #include "ggml.h"
  3. #include <math.h>
  4. #include <string.h>
  5. #include <assert.h>
  6. #ifdef __ARM_NEON
  7. // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
  8. //
  9. // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
  10. //
  11. #include <arm_neon.h>
  12. #else
  13. #ifdef __wasm_simd128__
  14. #include <wasm_simd128.h>
  15. #else
  16. #ifdef __POWER9_VECTOR__
  17. #include <altivec.h>
  18. #undef bool
  19. #define bool _Bool
  20. #else
  21. #if defined(_MSC_VER) || defined(__MINGW32__)
  22. #include <intrin.h>
  23. #else
  24. #if !defined(__riscv)
  25. #include <immintrin.h>
  26. #endif
  27. #endif
  28. #endif
  29. #endif
  30. #endif
  31. #undef MIN
  32. #undef MAX
  33. #define MIN(a, b) ((a) < (b) ? (a) : (b))
  34. #define MAX(a, b) ((a) > (b) ? (a) : (b))
  35. //
  36. // 2-6 bit quantization in super-blocks
  37. //
  38. //
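// All of the k-quant formats below operate on super-blocks of QK_K (= 256) weights. A super-block is
// split into sub-blocks of 16 or 32 values; each sub-block gets its own small integer scale (and, for
// the asymmetric formats, a min), and the super-block carries one or two fp16 factors (d, and dmin
// where present) that those sub-block scales multiply at dequantization.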
  39. // ===================== Helper functions
  40. //
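// nearest_int() below rounds via the usual float trick: for |fval| up to about 2^22 (the assert checks
// the upper bound), adding 1.5 * 2^23 = 12582912.0f moves the value into a range where the float's unit
// in the last place is 1, so after the addition the low 23 mantissa bits hold round(fval) + 0x00400000;
// masking and subtracting recovers the rounded integer without a float-to-int conversion.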
  41. static inline int nearest_int(float fval) {
  42. assert(fval <= 4194303.f);
  43. float val = fval + 12582912.f;
  44. int i; memcpy(&i, &val, sizeof(int));
  45. return (i & 0x007fffff) - 0x00400000;
  46. }
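// make_qx_quants() below maps n values to integer levels in [-nmax, nmax-1] (stored in L shifted by
// +nmax) and returns the scale. For rmse_type != 0 it minimizes, in effect, the weighted error
// sum_i w_i*(x_i - scale*l_i)^2, with w_i = x_i^2 when rmse_type is odd and w_i = 1 otherwise, using the
// closed-form optimum scale = sum(w*x*l)/sum(w*l*l) plus a few rounds of greedy per-element refinement;
// rmse_type >= 3 additionally retries a handful of slightly perturbed starting scales and keeps the best.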
  47. static float make_qx_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, int rmse_type) {
  48. float max = 0;
  49. float amax = 0;
  50. for (int i = 0; i < n; ++i) {
  51. float ax = fabsf(x[i]);
  52. if (ax > amax) { amax = ax; max = x[i]; }
  53. }
  54. if (!amax) { // all zero
  55. for (int i = 0; i < n; ++i) {
  56. L[i] = 0;
  57. }
  58. return 0.f;
  59. }
  60. float iscale = -nmax / max;
  61. if (rmse_type == 0) {
  62. for (int i = 0; i < n; ++i) {
  63. int l = nearest_int(iscale * x[i]);
  64. L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
  65. }
  66. return 1/iscale;
  67. }
  68. int weight_type = rmse_type%2;
  69. float sumlx = 0;
  70. float suml2 = 0;
  71. for (int i = 0; i < n; ++i) {
  72. int l = nearest_int(iscale * x[i]);
  73. l = MAX(-nmax, MIN(nmax-1, l));
  74. L[i] = l + nmax;
  75. float w = weight_type == 1 ? x[i] * x[i] : 1;
  76. sumlx += w*x[i]*l;
  77. suml2 += w*l*l;
  78. }
  79. float scale = sumlx/suml2;
  80. float best = scale * sumlx;
  81. for (int itry = 0; itry < 3; ++itry) {
  82. iscale = 1/scale;
  83. float slx = 0;
  84. float sl2 = 0;
  85. bool changed = false;
  86. for (int i = 0; i < n; ++i) {
  87. int l = nearest_int(iscale * x[i]);
  88. l = MAX(-nmax, MIN(nmax-1, l));
  89. if (l + nmax != L[i]) { changed = true; }
  90. float w = weight_type == 1 ? x[i] * x[i] : 1.f;
  91. slx += w*x[i]*l;
  92. sl2 += w*l*l;
  93. }
  94. if (!changed || sl2 == 0 || slx*slx <= best*sl2) { break; }
  95. for (int i = 0; i < n; ++i) {
  96. int l = nearest_int(iscale * x[i]);
  97. L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
  98. }
  99. sumlx = slx; suml2 = sl2;
  100. scale = sumlx/suml2;
  101. best = scale * sumlx;
  102. }
  103. for (int itry = 0; itry < 5; ++itry) {
  104. int n_changed = 0;
  105. for (int i = 0; i < n; ++i) {
  106. float w = weight_type == 1 ? x[i]*x[i] : 1;
  107. int l = L[i] - nmax;
  108. float slx = sumlx - w*x[i]*l;
  109. if (slx > 0) {
  110. float sl2 = suml2 - w*l*l;
  111. int new_l = nearest_int(x[i] * sl2 / slx);
  112. new_l = MAX(-nmax, MIN(nmax-1, new_l));
  113. if (new_l != l) {
  114. slx += w*x[i]*new_l;
  115. sl2 += w*new_l*new_l;
  116. if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
  117. L[i] = nmax + new_l; sumlx = slx; suml2 = sl2;
  118. scale = sumlx / suml2; best = scale * sumlx;
  119. ++n_changed;
  120. }
  121. }
  122. }
  123. }
  124. if (!n_changed) { break; }
  125. }
  126. if (rmse_type < 3) {
  127. return scale;
  128. }
  129. for (int is = -4; is <= 4; ++is) {
  130. if (is == 0) {
  131. continue;
  132. }
  133. iscale = -(nmax + 0.1f*is) / max;
  134. sumlx = suml2 = 0;
  135. for (int i = 0; i < n; ++i) {
  136. int l = nearest_int(iscale * x[i]);
  137. l = MAX(-nmax, MIN(nmax-1, l));
  138. float w = weight_type == 1 ? x[i] * x[i] : 1;
  139. sumlx += w*x[i]*l;
  140. suml2 += w*l*l;
  141. }
  142. if (suml2 > 0 && sumlx*sumlx > best*suml2) {
  143. for (int i = 0; i < n; ++i) {
  144. int l = nearest_int(iscale * x[i]);
  145. L[i] = nmax + MAX(-nmax, MIN(nmax-1, l));
  146. }
  147. scale = sumlx/suml2; best = scale*sumlx;
  148. }
  149. }
  150. return scale;
  151. }
  152. static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t * restrict L, bool do_rmse) {
  153. float max = 0;
  154. float amax = 0;
  155. for (int i = 0; i < n; ++i) {
  156. float ax = fabsf(x[i]);
  157. if (ax > amax) { amax = ax; max = x[i]; }
  158. }
  159. if (!amax) { // all zero
  160. for (int i = 0; i < n; ++i) { L[i] = 0; }
  161. return 0.f;
  162. }
  163. float iscale = -nmax / max;
  164. if (do_rmse) {
  165. float sumlx = 0;
  166. float suml2 = 0;
  167. for (int i = 0; i < n; ++i) {
  168. int l = nearest_int(iscale * x[i]);
  169. l = MAX(-nmax, MIN(nmax-1, l));
  170. L[i] = l;
  171. float w = x[i]*x[i];
  172. sumlx += w*x[i]*l;
  173. suml2 += w*l*l;
  174. }
  175. for (int itry = 0; itry < 5; ++itry) {
  176. int n_changed = 0;
  177. for (int i = 0; i < n; ++i) {
  178. float w = x[i]*x[i];
  179. float slx = sumlx - w*x[i]*L[i];
  180. if (slx > 0) {
  181. float sl2 = suml2 - w*L[i]*L[i];
  182. int new_l = nearest_int(x[i] * sl2 / slx);
  183. new_l = MAX(-nmax, MIN(nmax-1, new_l));
  184. if (new_l != L[i]) {
  185. slx += w*x[i]*new_l;
  186. sl2 += w*new_l*new_l;
  187. if (sl2 > 0 && slx*slx*suml2 > sumlx*sumlx*sl2) {
  188. L[i] = new_l; sumlx = slx; suml2 = sl2;
  189. ++n_changed;
  190. }
  191. }
  192. }
  193. }
  194. if (!n_changed) {
  195. break;
  196. }
  197. }
  198. for (int i = 0; i < n; ++i) {
  199. L[i] += nmax;
  200. }
  201. return sumlx / suml2;
  202. }
  203. for (int i = 0; i < n; ++i) {
  204. int l = nearest_int(iscale * x[i]);
  205. l = MAX(-nmax, MIN(nmax-1, l));
  206. L[i] = l + nmax;
  207. }
  208. return 1/iscale;
  209. }
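// make_qkx1_quants() below handles the asymmetric case x ~= scale*L + min with unsigned levels 0..nmax:
// it alternates between re-rounding L for the current (scale, min), refitting scale by least squares,
// and refitting min as the mean residual (clamped to <= 0). It returns the scale and stores -min in
// *the_min, so callers treat the offset as a positive "min" that is subtracted again at dequantization.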
  210. static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min, int ntry) {
  211. float min = x[0];
  212. float max = x[0];
  213. for (int i = 1; i < n; ++i) {
  214. if (x[i] < min) min = x[i];
  215. if (x[i] > max) max = x[i];
  216. }
  217. if (max == min) {
  218. for (int i = 0; i < n; ++i) L[i] = 0;
  219. *the_min = 0;
  220. return 0.f;
  221. }
  222. if (min > 0) min = 0;
  223. float iscale = nmax/(max - min);
  224. float scale = 1/iscale;
  225. for (int itry = 0; itry < ntry; ++itry) {
  226. float sumlx = 0; int suml2 = 0;
  227. bool did_change = false;
  228. for (int i = 0; i < n; ++i) {
  229. int l = nearest_int(iscale*(x[i] - min));
  230. l = MAX(0, MIN(nmax, l));
  231. if (l != L[i]) {
  232. L[i] = l;
  233. did_change = true;
  234. }
  235. sumlx += (x[i] - min)*l;
  236. suml2 += l*l;
  237. }
  238. scale = sumlx/suml2;
  239. float sum = 0;
  240. for (int i = 0; i < n; ++i) {
  241. sum += x[i] - scale*L[i];
  242. }
  243. min = sum/n;
  244. if (min > 0) min = 0;
  245. iscale = 1/scale;
  246. if (!did_change) break;
  247. }
  248. *the_min = -min;
  249. return scale;
  250. }
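// get_scale_min_k4() below unpacks the 6-bit scales/mins that Q4_K and Q5_K keep in a 12-byte array: for
// sub-blocks 0-3 the scale and min sit in the low 6 bits of bytes j and j+4; for sub-blocks 4-7 the low
// 4 bits come from byte j+4 (scale in the low nibble, min in the high nibble) and the top 2 bits come
// from bits 6-7 of bytes j-4 and j, matching the packing done in quantize_row_q4_K_reference below.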
  251. static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
  252. if (j < 4) {
  253. *d = q[j] & 63; *m = q[j + 4] & 63;
  254. } else {
  255. *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
  256. *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
  257. }
  258. }
  259. //========================= 2-bit (de)-quantization
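// block_q2_K packs 16 sub-blocks of 16 weights: each byte of scales[] holds a 4-bit sub-block scale in
// the low nibble and a 4-bit sub-block min in the high nibble, both relative to the fp16 super-block
// factors d and dmin, and qs[] stores the 2-bit quants four to a byte. Per value, dequantization is
// roughly x ~= d*(scale & 0xF)*q - dmin*(scale >> 4).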
  260. void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int k) {
  261. assert(k % QK_K == 0);
  262. const int nb = k / QK_K;
  263. uint8_t L[QK_K];
  264. float mins[QK_K/16];
  265. float scales[QK_K/16];
  266. const float q4scale = 15.f;
  267. for (int i = 0; i < nb; i++) {
  268. float max_scale = 0; // as we are deducting the min, scales are always positive
  269. float max_min = 0;
  270. for (int j = 0; j < QK_K/16; ++j) {
  271. scales[j] = make_qkx1_quants(16, 3, x + 16*j, L + 16*j, &mins[j], 5);
  272. float scale = scales[j];
  273. if (scale > max_scale) {
  274. max_scale = scale;
  275. }
  276. float min = mins[j];
  277. if (min > max_min) {
  278. max_min = min;
  279. }
  280. }
  281. if (max_scale > 0) {
  282. float iscale = q4scale/max_scale;
  283. for (int j = 0; j < QK_K/16; ++j) {
  284. int l = nearest_int(iscale*scales[j]);
  285. y[i].scales[j] = l;
  286. }
  287. y[i].d = ggml_fp32_to_fp16(max_scale/q4scale);
  288. } else {
  289. for (int j = 0; j < QK_K/16; ++j) y[i].scales[j] = 0;
  290. y[i].d = ggml_fp32_to_fp16(0.f);
  291. }
  292. if (max_min > 0) {
  293. float iscale = q4scale/max_min;
  294. for (int j = 0; j < QK_K/16; ++j) {
  295. int l = nearest_int(iscale*mins[j]);
  296. y[i].scales[j] |= (l << 4);
  297. }
  298. y[i].dmin = ggml_fp32_to_fp16(max_min/q4scale);
  299. } else {
  300. y[i].dmin = ggml_fp32_to_fp16(0.f);
  301. }
  302. for (int j = 0; j < QK_K/16; ++j) {
  303. const float d = ggml_fp16_to_fp32(y[i].d) * (y[i].scales[j] & 0xF);
  304. if (!d) continue;
  305. const float dm = ggml_fp16_to_fp32(y[i].dmin) * (y[i].scales[j] >> 4);
  306. for (int ii = 0; ii < 16; ++ii) {
  307. int l = nearest_int((x[16*j + ii] + dm)/d);
  308. l = MAX(0, MIN(3, l));
  309. L[16*j + ii] = l;
  310. }
  311. }
  312. for (int j = 0; j < QK_K; j += 128) {
  313. for (int l = 0; l < 32; ++l) {
  314. y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
  315. }
  316. }
  317. x += QK_K;
  318. }
  319. }
  320. void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int k) {
  321. assert(k % QK_K == 0);
  322. const int nb = k / QK_K;
  323. for (int i = 0; i < nb; i++) {
  324. const float d = ggml_fp16_to_fp32(x[i].d);
  325. const float min = ggml_fp16_to_fp32(x[i].dmin);
  326. const uint8_t * q = x[i].qs;
  327. int is = 0;
  328. float dl, ml;
  329. for (int n = 0; n < QK_K; n += 128) {
  330. int shift = 0;
  331. for (int j = 0; j < 4; ++j) {
  332. uint8_t sc = x[i].scales[is++];
  333. dl = d * (sc & 0xF); ml = min * (sc >> 4);
  334. for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
  335. sc = x[i].scales[is++];
  336. dl = d * (sc & 0xF); ml = min * (sc >> 4);
  337. for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
  338. shift += 2;
  339. }
  340. q += 32;
  341. }
  342. }
  343. }
  344. void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) {
  345. quantize_row_q2_K_reference(x, vy, k);
  346. }
  347. size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
  348. const int nb = k / QK_K;
  349. // TODO - collect histograms - although, on second thought, I don't really care about them
  350. (void)hist;
  351. for (int j = 0; j < nb; j += k) {
  352. block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K;
  353. quantize_row_q2_K_reference(src + j, y, k);
  354. }
  355. return (n/QK_K*sizeof(block_q2_K));
  356. }
  357. //========================= 3-bit (de)-quantization
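// block_q3_K stores effectively signed 3-bit quants in [-4, 3]: the two low bits go into qs[] (four per
// byte) and the third bit into hmask[], with 4 subtracted at dequantization when the high bit is clear.
// The 16 sub-block scales are 6-bit values stored with a +32 offset and packed into the 12-byte scales[]
// array; the effective per-value scale is d * (sc - 32).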
  358. void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int k) {
  359. assert(k % QK_K == 0);
  360. const int nb = k / QK_K;
  361. int8_t L[QK_K];
  362. float scales[QK_K / 16];
  363. for (int i = 0; i < nb; i++) {
  364. float max_scale = 0;
  365. float amax = 0;
  366. for (int j = 0; j < QK_K/16; ++j) {
  367. scales[j] = make_q3_quants(16, 4, x + 16*j, L + 16*j, true);
  368. float scale = fabsf(scales[j]);
  369. if (scale > amax) {
  370. amax = scale; max_scale = scales[j];
  371. }
  372. }
  373. memset(y[i].scales, 0, 12);
  374. if (max_scale) {
  375. float iscale = -32.f/max_scale;
  376. for (int j = 0; j < QK_K/16; ++j) {
  377. int8_t l = nearest_int(iscale*scales[j]);
  378. l = MAX(-32, MIN(31, l)) + 32;
  379. if (j < 8) {
  380. y[i].scales[j] = l & 0xF;
  381. } else {
  382. y[i].scales[j-8] |= ((l & 0xF) << 4);
  383. }
  384. l >>= 4;
  385. y[i].scales[j%4 + 8] |= (l << (2*(j/4)));
  386. }
  387. y[i].d = ggml_fp32_to_fp16(1/iscale);
  388. } else {
  389. y[i].d = ggml_fp32_to_fp16(0.f);
  390. }
  391. int8_t sc;
  392. for (int j = 0; j < QK_K/16; ++j) {
  393. sc = j < 8 ? y[i].scales[j] & 0xF : y[i].scales[j-8] >> 4;
  394. sc = (sc | (((y[i].scales[8 + j%4] >> (2*(j/4))) & 3) << 4)) - 32;
  395. float d = ggml_fp16_to_fp32(y[i].d) * sc;
  396. if (!d) {
  397. continue;
  398. }
  399. for (int ii = 0; ii < 16; ++ii) {
  400. int l = nearest_int(x[16*j + ii]/d);
  401. l = MAX(-4, MIN(3, l));
  402. L[16*j + ii] = l + 4;
  403. }
  404. }
  405. memset(y[i].hmask, 0, QK_K/8);
  406. // We put the high-bit for the 1st 32 quants into bit 0, the next 32 into bit 1, etc.
  407. int m = 0;
  408. uint8_t hm = 1;
  409. for (int j = 0; j < QK_K; ++j) {
  410. if (L[j] > 3) {
  411. y[i].hmask[m] |= hm;
  412. L[j] -= 4;
  413. }
  414. if (++m == QK_K/8) {
  415. m = 0; hm <<= 1;
  416. }
  417. }
  418. for (int j = 0; j < QK_K; j += 128) {
  419. for (int l = 0; l < 32; ++l) {
  420. y[i].qs[j/4 + l] = L[j + l] | (L[j + l + 32] << 2) | (L[j + l + 64] << 4) | (L[j + l + 96] << 6);
  421. }
  422. }
  423. x += QK_K;
  424. }
  425. }
  426. void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int k) {
  427. assert(k % QK_K == 0);
  428. assert(QK_K == 256);
  429. const int nb = k / QK_K;
  430. const uint32_t kmask1 = 0x03030303;
  431. const uint32_t kmask2 = 0x0f0f0f0f;
  432. uint32_t aux[4];
  433. const int8_t * scales = (const int8_t*)aux;
  434. for (int i = 0; i < nb; i++) {
  435. const float d_all = ggml_fp16_to_fp32(x[i].d);
  436. const uint8_t * restrict q = x[i].qs;
  437. const uint8_t * restrict hm = x[i].hmask;
  438. uint8_t m = 1;
  439. memcpy(aux, x[i].scales, 12);
  440. uint32_t tmp = aux[2];
  441. aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
  442. aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
  443. aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
  444. aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
  445. int is = 0;
  446. float dl;
  447. for (int n = 0; n < QK_K; n += 128) {
  448. int shift = 0;
  449. for (int j = 0; j < 4; ++j) {
  450. dl = d_all * (scales[is++] - 32);
  451. for (int l = 0; l < 16; ++l) {
  452. *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
  453. }
  454. dl = d_all * (scales[is++] - 32);
  455. for (int l = 0; l < 16; ++l) {
  456. *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
  457. }
  458. shift += 2;
  459. m <<= 1;
  460. }
  461. q += 32;
  462. }
  463. }
  464. }
  465. void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) {
  466. quantize_row_q3_K_reference(x, vy, k);
  467. }
  468. size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
  469. const int nb = k / QK_K;
  470. // TODO - collect histograms - although, on second thought, I don't really care about them
  471. (void)hist;
  472. for (int j = 0; j < nb; j += k) {
  473. block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K;
  474. quantize_row_q3_K_reference(src + j, y, k);
  475. }
  476. return (n/QK_K*sizeof(block_q3_K));
  477. }
  478. // ====================== 4-bit (de)-quantization
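// block_q4_K uses 8 sub-blocks of 32 weights: 6-bit sub-block scales and mins (packed/unpacked via
// get_scale_min_k4) relative to the fp16 d and dmin, and 4-bit quants stored two per byte in qs[];
// per value, x ~= d*sc*q - dmin*m.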
  479. void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int k) {
  480. assert(k % QK_K == 0);
  481. const int nb = k / QK_K;
  482. uint8_t L[QK_K];
  483. float mins[QK_K/32];
  484. float scales[QK_K/32];
  485. for (int i = 0; i < nb; i++) {
  486. float max_scale = 0; // as we are deducting the min, scales are always positive
  487. float max_min = 0;
  488. for (int j = 0; j < QK_K/32; ++j) {
  489. scales[j] = make_qkx1_quants(32, 15, x + 32*j, L + 32*j, &mins[j], 5);
  490. float scale = scales[j];
  491. if (scale > max_scale) {
  492. max_scale = scale;
  493. }
  494. float min = mins[j];
  495. if (min > max_min) {
  496. max_min = min;
  497. }
  498. }
  499. float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
  500. float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
  501. for (int j = 0; j < QK_K/32; ++j) {
  502. uint8_t ls = nearest_int(inv_scale*scales[j]);
  503. uint8_t lm = nearest_int(inv_min*mins[j]);
  504. ls = MIN(63, ls);
  505. lm = MIN(63, lm);
  506. if (j < 4) {
  507. y[i].scales[j] = ls;
  508. y[i].scales[j+4] = lm;
  509. } else {
  510. y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
  511. y[i].scales[j-4] |= ((ls >> 4) << 6);
  512. y[i].scales[j-0] |= ((lm >> 4) << 6);
  513. }
  514. }
  515. y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
  516. y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
  517. uint8_t sc, m;
  518. for (int j = 0; j < QK_K/32; ++j) {
  519. get_scale_min_k4(j, y[i].scales, &sc, &m);
  520. const float d = ggml_fp16_to_fp32(y[i].d) * sc;
  521. if (!d) continue;
  522. const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
  523. for (int ii = 0; ii < 32; ++ii) {
  524. int l = nearest_int((x[32*j + ii] + dm)/d);
  525. l = MAX(0, MIN(15, l));
  526. L[32*j + ii] = l;
  527. }
  528. }
  529. uint8_t * q = y[i].qs;
  530. for (int j = 0; j < QK_K; j += 64) {
  531. for (int l = 0; l < 32; ++l) *q++ = L[j + l] | (L[j + l + 32] << 4);
  532. }
  533. x += QK_K;
  534. }
  535. }
  536. void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int k) {
  537. assert(k % QK_K == 0);
  538. const int nb = k / QK_K;
  539. for (int i = 0; i < nb; i++) {
  540. const float d = ggml_fp16_to_fp32(x[i].d);
  541. const float min = ggml_fp16_to_fp32(x[i].dmin);
  542. const uint8_t * q = x[i].qs;
  543. int is = 0;
  544. uint8_t sc, m;
  545. for (int j = 0; j < QK_K; j += 64) {
  546. get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
  547. const float d1 = d * sc; const float m1 = min * m;
  548. get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
  549. const float d2 = d * sc; const float m2 = min * m;
  550. for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
  551. for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
  552. q += 32; is += 2;
  553. }
  554. }
  555. }
  556. void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) {
  557. assert(k % QK_K == 0);
  558. block_q4_K * restrict y = vy;
  559. quantize_row_q4_K_reference(x, y, k);
  560. }
  561. size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
  562. assert(k % QK_K == 0);
  563. const int nb = k / QK_K;
  564. (void)hist; // TODO: collect histograms
  565. for (int j = 0; j < nb; j += k) {
  566. block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K;
  567. quantize_row_q4_K_reference(src + j, y, k);
  568. }
  569. return (n/QK_K*sizeof(block_q4_K));
  570. }
  571. // ====================== 5-bit (de)-quantization
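// block_q5_K mirrors Q4_K (same 6-bit scales/mins against d and dmin) but keeps 5-bit quants: the low
// four bits go into qs[] and the fifth bit into qh[], which adds 16 to the stored nibble when set.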
  572. void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int k) {
  573. assert(k % QK_K == 0);
  574. const int nb = k / QK_K;
  575. uint8_t L[QK_K];
  576. float mins[QK_K/32];
  577. float scales[QK_K/32];
  578. for (int i = 0; i < nb; i++) {
  579. float max_scale = 0; // as we are deducting the min, scales are always positive
  580. float max_min = 0;
  581. for (int j = 0; j < QK_K/32; ++j) {
  582. scales[j] = make_qkx1_quants(32, 31, x + 32*j, L + 32*j, &mins[j], 5);
  583. float scale = scales[j];
  584. if (scale > max_scale) {
  585. max_scale = scale;
  586. }
  587. float min = mins[j];
  588. if (min > max_min) {
  589. max_min = min;
  590. }
  591. }
  592. float inv_scale = max_scale > 0 ? 63.f/max_scale : 0.f;
  593. float inv_min = max_min > 0 ? 63.f/max_min : 0.f;
  594. for (int j = 0; j < QK_K/32; ++j) {
  595. uint8_t ls = nearest_int(inv_scale*scales[j]);
  596. uint8_t lm = nearest_int(inv_min*mins[j]);
  597. ls = MIN(63, ls);
  598. lm = MIN(63, lm);
  599. if (j < 4) {
  600. y[i].scales[j] = ls;
  601. y[i].scales[j+4] = lm;
  602. } else {
  603. y[i].scales[j+4] = (ls & 0xF) | ((lm & 0xF) << 4);
  604. y[i].scales[j-4] |= ((ls >> 4) << 6);
  605. y[i].scales[j-0] |= ((lm >> 4) << 6);
  606. }
  607. }
  608. y[i].d = ggml_fp32_to_fp16(max_scale/63.f);
  609. y[i].dmin = ggml_fp32_to_fp16(max_min/63.f);
  610. uint8_t sc, m;
  611. for (int j = 0; j < QK_K/32; ++j) {
  612. get_scale_min_k4(j, y[i].scales, &sc, &m);
  613. const float d = ggml_fp16_to_fp32(y[i].d) * sc;
  614. if (!d) continue;
  615. const float dm = ggml_fp16_to_fp32(y[i].dmin) * m;
  616. for (int ii = 0; ii < 32; ++ii) {
  617. int l = nearest_int((x[32*j + ii] + dm)/d);
  618. l = MAX(0, MIN(31, l));
  619. L[32*j + ii] = l;
  620. }
  621. }
  622. uint8_t * restrict qh = y[i].qh;
  623. uint8_t * restrict ql = y[i].qs;
  624. memset(qh, 0, QK_K/8);
  625. uint8_t m1 = 1, m2 = 2;
  626. for (int n = 0; n < QK_K; n += 64) {
  627. for (int j = 0; j < 32; ++j) {
  628. int l1 = L[n + j];
  629. if (l1 > 15) {
  630. l1 -= 16; qh[j] |= m1;
  631. }
  632. int l2 = L[n + j + 32];
  633. if (l2 > 15) {
  634. l2 -= 16; qh[j] |= m2;
  635. }
  636. ql[j] = l1 | (l2 << 4);
  637. }
  638. m1 <<= 2; m2 <<= 2;
  639. ql += 32;
  640. }
  641. x += QK_K;
  642. }
  643. }
  644. void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int k) {
  645. assert(k % QK_K == 0);
  646. const int nb = k / QK_K;
  647. for (int i = 0; i < nb; i++) {
  648. const float d = ggml_fp16_to_fp32(x[i].d);
  649. const float min = ggml_fp16_to_fp32(x[i].dmin);
  650. const uint8_t * ql = x[i].qs;
  651. const uint8_t * qh = x[i].qh;
  652. int is = 0;
  653. uint8_t sc, m;
  654. uint8_t u1 = 1, u2 = 2;
  655. for (int j = 0; j < QK_K; j += 64) {
  656. get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
  657. const float d1 = d * sc; const float m1 = min * m;
  658. get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
  659. const float d2 = d * sc; const float m2 = min * m;
  660. for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
  661. for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
  662. ql += 32; is += 2;
  663. u1 <<= 2; u2 <<= 2;
  664. }
  665. }
  666. }
  667. void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) {
  668. assert(k % QK_K == 0);
  669. block_q5_K * restrict y = vy;
  670. quantize_row_q5_K_reference(x, y, k);
  671. }
  672. size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) {
  673. assert(k % QK_K == 0);
  674. const int nb = k / QK_K;
  675. (void)hist;
  676. for (int j = 0; j < nb; j += k) {
  677. block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K;
  678. quantize_row_q5_K_reference(src + j, y, k);
  679. }
  680. return (n/QK_K*sizeof(block_q5_K));
  681. }
  682. // ====================== 6-bit (de)-quantization
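// block_q6_K uses 16 sub-blocks of 16 weights with signed 8-bit sub-block scales against the fp16 d.
// Quants are 6-bit values in [-32, 31] stored with a +32 offset: the low four bits go into ql[] and the
// upper two bits into qh[]; per value, x ~= d * scales[j] * (q - 32).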
  683. void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int k) {
  684. assert(k % QK_K == 0);
  685. const int nb = k / QK_K;
  686. int8_t L[QK_K];
  687. float scales[QK_K/16];
  688. for (int i = 0; i < nb; i++) {
  689. float max_scale = 0;
  690. float max_abs_scale = 0;
  691. for (int ib = 0; ib < QK_K/16; ++ib) {
  692. const float scale = make_qx_quants(16, 32, x + 16*ib, L + 16*ib, 1);
  693. scales[ib] = scale;
  694. const float abs_scale = fabsf(scale);
  695. if (abs_scale > max_abs_scale) {
  696. max_abs_scale = abs_scale;
  697. max_scale = scale;
  698. }
  699. }
  700. float iscale = -128.f/max_scale;
  701. y[i].d = ggml_fp32_to_fp16(1/iscale);
  702. for (int ib = 0; ib < QK_K/16; ++ib) {
  703. y[i].scales[ib] = MIN(127, nearest_int(iscale*scales[ib]));
  704. }
  705. for (int j = 0; j < QK_K/16; ++j) {
  706. float d = ggml_fp16_to_fp32(y[i].d) * y[i].scales[j];
  707. if (!d) {
  708. continue;
  709. }
  710. for (int ii = 0; ii < 16; ++ii) {
  711. int l = nearest_int(x[16*j + ii]/d);
  712. l = MAX(-32, MIN(31, l));
  713. L[16*j + ii] = l + 32;
  714. }
  715. }
  716. uint8_t * restrict ql = y[i].ql;
  717. uint8_t * restrict qh = y[i].qh;
  718. for (int j = 0; j < QK_K; j += 128) {
  719. for (int l = 0; l < 32; ++l) {
  720. const uint8_t q1 = L[j + l + 0] & 0xF;
  721. const uint8_t q2 = L[j + l + 32] & 0xF;
  722. const uint8_t q3 = L[j + l + 64] & 0xF;
  723. const uint8_t q4 = L[j + l + 96] & 0xF;
  724. ql[l+ 0] = q1 | (q3 << 4);
  725. ql[l+32] = q2 | (q4 << 4);
  726. qh[l] = (L[j + l] >> 4) | ((L[j + l + 32] >> 4) << 2) | ((L[j + l + 64] >> 4) << 4) | ((L[j + l + 96] >> 4) << 6);
  727. }
  728. ql += 64;
  729. qh += 32;
  730. }
  731. x += QK_K;
  732. }
  733. }
  734. void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k) {
  735. assert(k % QK_K == 0);
  736. const int nb = k / QK_K;
  737. for (int i = 0; i < nb; i++) {
  738. const float d = ggml_fp16_to_fp32(x[i].d);
  739. const uint8_t * restrict ql = x[i].ql;
  740. const uint8_t * restrict qh = x[i].qh;
  741. const int8_t * restrict sc = x[i].scales;
  742. for (int n = 0; n < QK_K; n += 128) {
  743. for (int l = 0; l < 32; ++l) {
  744. int is = l/16;
  745. const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
  746. const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
  747. const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
  748. const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
  749. y[l + 0] = d * sc[is + 0] * q1;
  750. y[l + 32] = d * sc[is + 2] * q2;
  751. y[l + 64] = d * sc[is + 4] * q3;
  752. y[l + 96] = d * sc[is + 6] * q4;
  753. }
  754. y += 128;
  755. ql += 64;
  756. qh += 32;
  757. sc += 8;
  758. }
  759. }
  760. }
  761. void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) {
  762. assert(k % QK_K == 0);
  763. block_q6_K * restrict y = vy;
  764. quantize_row_q6_K_reference(x, y, k);
  765. }
  766. size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) {
  767. assert(k % QK_K == 0);
  768. const int nb = k / QK_K;
  769. (void)hist; // TODO
  770. for (int j = 0; j < nb; j += k) {
  771. block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K;
  772. quantize_row_q6_K_reference(src + j, y, k);
  773. }
  774. return (n/QK_K*sizeof(block_q6_K));
  775. }
  776. //===================================== Q8_K ==============================================
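// block_q8_K is the plain 8-bit format used for the second operand (typically the activations) of the
// *_q8_K dot products below: a float scale d, QK_K int8 quants, and bsums[] holding the sum of every
// group of 16 quants, which lets the kernels fold the per-sub-block mins into a single multiply against
// bsums instead of touching every element.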
  777. void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
  778. assert(k % QK_K == 0);
  779. const int nb = k / QK_K;
  780. for (int i = 0; i < nb; i++) {
  781. float max = 0;
  782. float amax = 0;
  783. for (int j = 0; j < QK_K; ++j) {
  784. float ax = fabsf(x[j]);
  785. if (ax > amax) {
  786. amax = ax; max = x[j];
  787. }
  788. }
  789. if (!amax) {
  790. y[i].d = 0;
  791. memset(y[i].qs, 0, QK_K);
  792. x += QK_K;
  793. continue;
  794. }
  795. const float iscale = -128.f/max;
  796. for (int j = 0; j < QK_K; ++j) {
  797. int v = nearest_int(iscale*x[j]);
  798. y[i].qs[j] = MIN(127, v);
  799. }
  800. for (int j = 0; j < QK_K/16; ++j) {
  801. int sum = 0;
  802. for (int ii = 0; ii < 16; ++ii) {
  803. sum += y[i].qs[j*16 + ii];
  804. }
  805. y[i].bsums[j] = sum;
  806. }
  807. y[i].d = 1/iscale;
  808. x += QK_K;
  809. }
  810. }
  811. void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k) {
  812. assert(k % QK_K == 0);
  813. const int nb = k / QK_K;
  814. for (int i = 0; i < nb; i++) {
  815. for (int j = 0; j < QK_K; ++j) {
  816. *y++ = x[i].d * x[i].qs[j];
  817. }
  818. }
  819. }
  820. void quantize_row_q8_K(const float * restrict x, void * restrict y, int k) {
  821. quantize_row_q8_K_reference(x, y, k);
  822. }
  823. //===================================== Dot products =================================
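// The ggml_vec_dot_*_q8_K kernels below all compute the same quantity per super-block: roughly
// d_y*d_x * sum_j sc_j * dot(q_j, q8_j), minus d_y*dmin * sum_j m_j * bsums_j for the formats that carry
// a min, where bsums are the precomputed per-16 sums from block_q8_K. The NEON, AVX2 and scalar branches
// differ mainly in how they unpack the quants and apply the sub-block scales.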
  824. //
  825. // Helper functions
  826. //
  827. #if __AVX__ || __AVX2__ || __AVX512F__
  828. // horizontally add 8 floats
  829. static inline float hsum_float_8(const __m256 x) {
  830. __m128 res = _mm256_extractf128_ps(x, 1);
  831. res = _mm_add_ps(res, _mm256_castps256_ps128(x));
  832. res = _mm_add_ps(res, _mm_movehl_ps(res, res));
  833. res = _mm_add_ss(res, _mm_movehdup_ps(res));
  834. return _mm_cvtss_f32(res);
  835. }
  836. // shuffles to pick the required scales in dot products
  837. static inline __m256i get_scale_shuffle_q3k(int i) {
  838. static const uint8_t k_shuffle[128] = {
  839. 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
  840. 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
  841. 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
  842. 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13, 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,
  843. };
  844. return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
  845. }
  846. static inline __m256i get_scale_shuffle_k4(int i) {
  847. static const uint8_t k_shuffle[256] = {
  848. 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1,
  849. 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
  850. 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5, 4, 5,
  851. 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7, 6, 7,
  852. 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9, 8, 9,
  853. 10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,10,11,
  854. 12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,12,13,
  855. 14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15,14,15
  856. };
  857. return _mm256_loadu_si256((const __m256i*)k_shuffle + i);
  858. }
  859. static inline __m128i get_scale_shuffle(int i) {
  860. static const uint8_t k_shuffle[128] = {
  861. 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
  862. 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
  863. 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5,
  864. 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7,
  865. 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9,
  866. 10,10,10,10,10,10,10,10, 11,11,11,11,11,11,11,11,
  867. 12,12,12,12,12,12,12,12, 13,13,13,13,13,13,13,13,
  868. 14,14,14,14,14,14,14,14, 15,15,15,15,15,15,15,15
  869. };
  870. return _mm_loadu_si128((const __m128i*)k_shuffle + i);
  871. }
  872. #endif
  873. void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
  874. const block_q2_K * restrict x = vx;
  875. const block_q8_K * restrict y = vy;
  876. const int nb = n / QK_K;
  877. #ifdef __ARM_NEON
  878. const uint8x16_t m3 = vdupq_n_u8(0x3);
  879. const uint8x16_t m4 = vdupq_n_u8(0xF);
  880. const int32x4_t vzero = vdupq_n_s32(0);
  881. int8x16x2_t q2bytes;
  882. uint8_t aux[16];
  883. float sum = 0;
  884. for (int i = 0; i < nb; ++i) {
  885. const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
  886. const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
  887. const uint8_t * restrict q2 = x[i].qs;
  888. const int8_t * restrict q8 = y[i].qs;
  889. const uint8_t * restrict sc = x[i].scales;
  890. const uint8x16_t mins_and_scales = vld1q_u8(sc);
  891. const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
  892. vst1q_u8(aux, scales);
  893. const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
  894. const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
  895. const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
  896. const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
  897. vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
  898. const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
  899. vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
  900. sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
  901. int isum = 0;
  902. int is = 0;
  903. // We use this macro instead of a function call because, for some reason,
  904. // a function call makes the code run 2-3% slower, even when the function is declared inline
  905. #if defined(__ARM_FEATURE_DOTPROD)
  906. #define MULTIPLY_ACCUM_WITH_SCALE(index)\
  907. isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
  908. isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
  909. #else
  910. #define MULTIPLY_ACCUM_WITH_SCALE(index)\
  911. {\
  912. const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
  913. vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
  914. const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
  915. vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
  916. isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
  917. }
  918. #endif
  919. #define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
  920. q8bytes = vld1q_s8_x2(q8); q8 += 32;\
  921. q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
  922. q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
  923. MULTIPLY_ACCUM_WITH_SCALE((index));
  924. for (int j = 0; j < QK_K/128; ++j) {
  925. const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
  926. int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
  927. q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
  928. q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
  929. MULTIPLY_ACCUM_WITH_SCALE(0);
  930. SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
  931. SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
  932. SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
  933. is += 8;
  934. }
  935. sum += d * isum;
  936. }
  937. *s = sum;
  938. #elif defined __AVX2__
  939. const __m256i m3 = _mm256_set1_epi8(3);
  940. const __m128i m4 = _mm_set1_epi8(0xF);
  941. __m256 acc = _mm256_setzero_ps();
  942. for (int i = 0; i < nb; ++i) {
  943. const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
  944. const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);
  945. const uint8_t * restrict q2 = x[i].qs;
  946. const int8_t * restrict q8 = y[i].qs;
  947. const __m128i mins_and_scales = _mm_loadu_si128((const __m128i*)x[i].scales);
  948. const __m128i scales8 = _mm_and_si128(mins_and_scales, m4);
  949. const __m128i mins8 = _mm_and_si128(_mm_srli_epi16(mins_and_scales, 4), m4);
  950. const __m256i mins = _mm256_cvtepi8_epi16(mins8);
  951. const __m256i prod = _mm256_madd_epi16(mins, _mm256_loadu_si256((const __m256i*)y[i].bsums));
  952. acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&dmin), _mm256_cvtepi32_ps(prod), acc);
  953. const __m256i all_scales = _mm256_cvtepi8_epi16(scales8);
  954. const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  955. const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
  956. const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
  957. __m256i sumi = _mm256_setzero_si256();
  958. for (int j = 0; j < QK_K/128; ++j) {
  959. const __m256i q2bits = _mm256_loadu_si256((const __m256i*)q2); q2 += 32;
  960. const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  961. const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  962. const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  963. const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  964. const __m256i q2_0 = _mm256_and_si256(q2bits, m3);
  965. const __m256i q2_1 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 2), m3);
  966. const __m256i q2_2 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 4), m3);
  967. const __m256i q2_3 = _mm256_and_si256(_mm256_srli_epi16(q2bits, 6), m3);
  968. __m256i p0 = _mm256_maddubs_epi16(q2_0, q8_0);
  969. __m256i p1 = _mm256_maddubs_epi16(q2_1, q8_1);
  970. __m256i p2 = _mm256_maddubs_epi16(q2_2, q8_2);
  971. __m256i p3 = _mm256_maddubs_epi16(q2_3, q8_3);
  972. p0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(0)), p0);
  973. p1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(1)), p1);
  974. p2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(2)), p2);
  975. p3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(3)), p3);
  976. p0 = _mm256_add_epi32(p0, p1);
  977. p2 = _mm256_add_epi32(p2, p3);
  978. sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p0, p2));
  979. }
  980. acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
  981. }
  982. *s = hsum_float_8(acc);
  983. #else
  984. float sumf = 0;
  985. for (int i = 0; i < nb; ++i) {
  986. const uint8_t * q2 = x[i].qs;
  987. const int8_t * q8 = y[i].qs;
  988. const uint8_t * sc = x[i].scales;
  989. int summs = 0;
  990. for (int j = 0; j < 16; ++j) {
  991. summs += y[i].bsums[j] * (sc[j] >> 4);
  992. }
  993. const float dall = y[i].d * ggml_fp16_to_fp32(x[i].d);
  994. const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
  995. int isum = 0;
  996. int is = 0;
  997. int d;
  998. for (int k = 0; k < QK_K/128; ++k) {
  999. int shift = 0;
  1000. for (int j = 0; j < 4; ++j) {
  1001. d = sc[is++] & 0xF;
  1002. int isuml = 0;
  1003. for (int l = 0; l < 16; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
  1004. isum += d * isuml;
  1005. d = sc[is++] & 0xF;
  1006. isuml = 0;
  1007. for (int l = 16; l < 32; ++l) isuml += q8[l] * ((q2[l] >> shift) & 3);
  1008. isum += d * isuml;
  1009. shift += 2;
  1010. q8 += 32;
  1011. }
  1012. q2 += 32;
  1013. }
  1014. sumf += dall * isum - dmin * summs;
  1015. }
  1016. *s = sumf;
  1017. #endif
  1018. }
  1019. void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
  1020. assert(n % QK_K == 0);
  1021. const uint32_t kmask1 = 0x03030303;
  1022. const uint32_t kmask2 = 0x0f0f0f0f;
  1023. const block_q3_K * restrict x = vx;
  1024. const block_q8_K * restrict y = vy;
  1025. const int nb = n / QK_K;
  1026. #ifdef __ARM_NEON
  1027. uint32_t aux[3];
  1028. uint32_t utmp[4];
  1029. const uint8x16_t m3b = vdupq_n_u8(0x3);
  1030. #ifdef __ARM_FEATURE_DOTPROD
  1031. const int32x4_t vzero = vdupq_n_s32(0);
  1032. #endif
  1033. const uint8x16_t m0 = vdupq_n_u8(1);
  1034. const uint8x16_t m1 = vshlq_n_u8(m0, 1);
  1035. const uint8x16_t m2 = vshlq_n_u8(m0, 2);
  1036. const uint8x16_t m3 = vshlq_n_u8(m0, 3);
  1037. const int8_t m32 = 32;
  1038. int8x16x4_t q3bytes;
  1039. float sum = 0;
  1040. for (int i = 0; i < nb; ++i) {
  1041. const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
  1042. const uint8_t * restrict q3 = x[i].qs;
  1043. const uint8_t * restrict qh = x[i].hmask;
  1044. const int8_t * restrict q8 = y[i].qs;
  1045. uint8x16x2_t qhbits = vld1q_u8_x2(qh);
  1046. uint8x16x4_t q3h;
  1047. int32_t isum = 0;
  1048. // Set up scales
  1049. memcpy(aux, x[i].scales, 12);
  1050. utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
  1051. utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
  1052. utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
  1053. utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
  1054. int8_t * scale = (int8_t *)utmp;
  1055. for (int j = 0; j < 16; ++j) scale[j] -= m32;
  1056. for (int j = 0; j < QK_K/128; ++j) {
  1057. const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
  1058. const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
  1059. const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
  1060. q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
  1061. q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
  1062. q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
  1063. q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
  1064. q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
  1065. q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
  1066. q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  1067. q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
  1068. #if defined(__ARM_FEATURE_DOTPROD)
  1069. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
  1070. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
  1071. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
  1072. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
  1073. #else
  1074. int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
  1075. vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
  1076. int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
  1077. vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
  1078. int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
  1079. vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
  1080. int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
  1081. vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
  1082. isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
  1083. #endif
  1084. scale += 4;
  1085. q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
  1086. q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
  1087. q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
  1088. q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
  1089. q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
  1090. q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
  1091. q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
  1092. q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
  1093. #if defined(__ARM_FEATURE_DOTPROD)
  1094. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
  1095. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
  1096. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
  1097. isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
  1098. #else
  1099. p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
  1100. vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
  1101. p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
  1102. vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
  1103. p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
  1104. vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
  1105. p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
  1106. vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
  1107. isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
  1108. #endif
  1109. scale += 4;
  1110. if (j == 0) {
  1111. qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
  1112. qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
  1113. }
  1114. }
  1115. sum += d * isum;
  1116. }
  1117. *s = sum;
  1118. #elif defined __AVX2__
  1119. const __m256i m3 = _mm256_set1_epi8(3);
  1120. const __m256i mone = _mm256_set1_epi8(1);
  1121. const __m128i m32 = _mm_set1_epi8(32);
  1122. __m256 acc = _mm256_setzero_ps();
  1123. uint32_t aux[3];
  1124. for (int i = 0; i < nb; ++i) {
  1125. const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
  1126. const uint8_t * restrict q3 = x[i].qs;
  1127. const int8_t * restrict q8 = y[i].qs;
  1128. // Set up scales
  1129. memcpy(aux, x[i].scales, 12);
  1130. __m128i scales128 = _mm_set_epi32(
  1131. ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4),
  1132. ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4),
  1133. (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4),
  1134. (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4));
  1135. scales128 = _mm_sub_epi8(scales128, m32);
  1136. const __m256i all_scales = _mm256_cvtepi8_epi16(scales128);
  1137. const __m128i l_scales = _mm256_extracti128_si256(all_scales, 0);
  1138. const __m128i h_scales = _mm256_extracti128_si256(all_scales, 1);
  1139. const __m256i scales[2] = {_mm256_set_m128i(l_scales, l_scales), _mm256_set_m128i(h_scales, h_scales)};
  1140. // high bit
  1141. const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].hmask);
  1142. // integer accumulator
  1143. __m256i sumi = _mm256_setzero_si256();
  1144. int bit = 0;
  1145. int is = 0;
  1146. for (int j = 0; j < QK_K/128; ++j) {
  1147. // load low 2 bits
  1148. const __m256i q3bits = _mm256_loadu_si256((const __m256i*)q3); q3 += 32;
  1149. // prepare low and high bits
  1150. const __m256i q3l_0 = _mm256_and_si256(q3bits, m3);
  1151. const __m256i q3h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
  1152. ++bit;
  1153. const __m256i q3l_1 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 2), m3);
  1154. const __m256i q3h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
  1155. ++bit;
  1156. const __m256i q3l_2 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 4), m3);
  1157. const __m256i q3h_2 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
  1158. ++bit;
  1159. const __m256i q3l_3 = _mm256_and_si256(_mm256_srli_epi16(q3bits, 6), m3);
  1160. const __m256i q3h_3 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_andnot_si256(hbits, _mm256_slli_epi16(mone, bit)), bit), 2);
  1161. ++bit;
  1162. // load Q8 quants
  1163. const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  1164. const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  1165. const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  1166. const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
  1167. // Dot product: we multiply the 2-low-bit part and the high-bit part separately, so we can use _mm256_maddubs_epi16,
  1168. // and then subtract. The high-bit part already carries the -4 offset: q3h is 4 if the high bit was not set
  1169. // and 0 if it was set, so subtracting its product with q8 applies the offset.
            __m256i q8s_0 = _mm256_maddubs_epi16(q3h_0, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(q3h_1, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(q3h_2, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(q3h_3, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q3l_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q3l_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q3l_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q3l_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            // multiply with scales
            p16_0 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 0)), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 1)), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 2)), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_shuffle_epi8(scales[j], get_scale_shuffle_q3k(is + 3)), p16_3);

            // accumulate
            p16_0 = _mm256_add_epi32(p16_0, p16_1);
            p16_2 = _mm256_add_epi32(p16_2, p16_3);
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_2));
        }

        // multiply with block scale and accumulate
        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);

    }

    *s = hsum_float_8(acc);

#else
    // scalar version
    // This function is written like this so the compiler can manage to vectorize most of it
    // Using -Ofast, GCC and clang manage to produce code that is within a factor of 2 or so from the
    // manually vectorized version above. Every other version I tried would run at least 4 times slower.
    // The ideal situation would be if we could just write the code once, and the compiler would
    // automatically produce the best possible set of machine instructions, instead of us having to manually
    // write vectorized versions for AVX, ARM_NEON, etc.
    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    uint32_t auxs[4];
    const int8_t * scales = (const int8_t*)auxs;

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q3 = x[i].qs;
        const uint8_t * restrict hm = x[i].hmask;
        const  int8_t * restrict q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * restrict a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) a[l] = q3[l] & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 2) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 4) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (q3[l] >> 6) & 3;
            for (int l = 0; l < 32; ++l) a[l] -= (hm[l] & m ? 0 : 4);
            a += 32; m <<= 1;
            q3 += 32;
        }
        a = aux8;

        memcpy(auxs, x[i].scales, 12);
        uint32_t tmp = auxs[2];
        auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
        auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
        auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
        auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
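        // auxs now holds 16 scale bytes, one per 16-value sub-block: the low 4 bits come from
        // scales[0..7] and the high 2 bits from scales[8..11] (16 6-bit scales packed into 12 bytes).
        // The stored scales are offset by 32, hence the (scales[j] - 32) below.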
        for (int j = 0; j < QK_K/16; ++j) {
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += (scales[j] - 32) * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}

void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);

    const block_q4_K * restrict x = vx;
    const block_q8_K * restrict y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#ifdef __ARM_NEON

    const uint8x16_t m4b = vdupq_n_u8(0xf);
#ifdef __ARM_FEATURE_DOTPROD
    const int32x4_t mzero = vdupq_n_s32(0);
#endif

    int8x16x2_t q4bytes;
    int8x16x2_t q8bytes;

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);

        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
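        // q4_K sub-blocks dequantize as d*sc*q - dmin*m; the 12 scale bytes pack 8 6-bit scales and
        // 8 6-bit mins. The lines below unpack the mins and dot them with the per-sub-block q8 sums
        // (bsums), so the -dmin * sum(m_j * bsum_j) term is handled before the main loop.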
        memcpy(utmp, x[i].scales, 12);

        const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[0] &= kmask1;

        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
        sumf -= dmin * vaddvq_s32(prod);

        const uint8_t * scales = (const uint8_t *)utmp;

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

        //int32x4_t isum = mzero;

        int32_t sumi1 = 0;
        int32_t sumi2 = 0;

        for (int j = 0; j < QK_K/64; ++j) {

            const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;

#ifdef __ARM_FEATURE_DOTPROD
            q8bytes = vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));

            const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
            sumi1 += vaddvq_s32(p1) * scales[2*j+0];

            q8bytes = vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));

            const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
            sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
            q8bytes = vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
            q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
            sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];

            q8bytes = vld1q_s8_x2(q8); q8 += 32;
            q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
            q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                           vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                           vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
            sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
#endif
        }

        sumf += d * (sumi1 + sumi2);
    }

    *s = sumf;

#elif defined __AVX2__

    const __m256i m4 = _mm256_set1_epi8(0xF);

    __m256 acc = _mm256_setzero_ps();
    __m128 acc_m = _mm_setzero_ps();

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

        const uint8_t * restrict q4 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;
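        // after this shuffle utmp[0..1] hold the 8 6-bit sub-block scales and utmp[2..3] the 8
        // 6-bit mins; mins_and_scales below widens them to 16 bits (scales in the low 128 bits,
        // mins in the high 128 bits).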
        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
        acc_m = _mm_fmadd_ps(_mm_set1_ps(dmin), _mm_cvtepi32_ps(prod), acc_m);

        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
        const __m256i scales = _mm256_set_m128i(sc128, sc128);

        __m256i sumi = _mm256_setzero_si256();

        for (int j = 0; j < QK_K/64; ++j) {

            const __m256i scale_l = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
            const __m256i scale_h = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));

            const __m256i q4bits = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4l = _mm256_and_si256(q4bits, m4);
            const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(q4bits, 4), m4);

            const __m256i q8l = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            __m256i p16l = _mm256_maddubs_epi16(q4l, q8l);
            p16l = _mm256_madd_epi16(scale_l, p16l);
            sumi = _mm256_add_epi32(sumi, p16l);

            const __m256i q8h = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            __m256i p16h = _mm256_maddubs_epi16(q4h, q8h);
            p16h = _mm256_madd_epi16(scale_h, p16h);
            sumi = _mm256_add_epi32(sumi, p16h);

        }

        __m256 vd = _mm256_set1_ps(d);
        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);

    }
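    // horizontally reduce the 4 lanes of acc_m (the -dmin * mins * bsums contribution) before
    // adding it to the 8-lane reduction of acc.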
    acc_m = _mm_add_ps(acc_m, _mm_movehl_ps(acc_m, acc_m));
    acc_m = _mm_add_ss(acc_m, _mm_movehdup_ps(acc_m));

    *s = hsum_float_8(acc) + _mm_cvtss_f32(acc_m);

#else

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins   = (const uint8_t*)&utmp[2];

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q4 = x[i].qs;
        const  int8_t * restrict q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * restrict a = aux8;
        for (int j = 0; j < QK_K/64; ++j) {
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
            a += 32;
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
            a += 32; q4 += 32;
        }
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        int sumi = 0;
        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];
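        // sumi accumulates sum_j(m_j * bsum_j): each 6-bit min applies to 32 values, i.e. to two
        // consecutive 16-value bsums entries; it is scaled by dmin and subtracted further down.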
        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            int32_t scale = scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
        const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
        sumf -= dmin * sumi;
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}

void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);

    const block_q5_K * restrict x = vx;
    const block_q8_K * restrict y = vy;

    const int nb = n / QK_K;

    static const uint32_t kmask1 = 0x3f3f3f3f;
    static const uint32_t kmask2 = 0x0f0f0f0f;
    static const uint32_t kmask3 = 0x03030303;

    uint32_t utmp[4];

#ifdef __ARM_NEON

    const uint8x16_t m4b = vdupq_n_u8(0xf);
    const int32x4_t mzero = vdupq_n_s32(0);
    const uint8x16_t mone = vdupq_n_u8(1);
    const uint8x16_t mtwo = vdupq_n_u8(2);

    int8x16x4_t q5bytes;

    float sumf = 0;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);

        const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        const uint8x8_t mins8 = vld1_u8((const uint8_t*)utmp + 8);
        const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(mins8));
        const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
                                         vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
        int32_t sumi_mins = vaddvq_s32(prod);

        const uint8_t * scales = (const uint8_t *)utmp;

        const uint8_t * restrict q5 = x[i].qs;
        const uint8_t * restrict qh = x[i].qh;
        const int8_t  * restrict q8 = y[i].qs;

        uint8x16x2_t qhbits = vld1q_u8_x2(qh);

        uint8x16x4_t q5h;

        int32_t sumi = 0;

        for (int j = 0; j < QK_K/64; ++j) {

            const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
            const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
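            // pull out the two high-bit planes for this 64-value chunk: bit 0 of qh adds 16 to the
            // low-nibble values, bit 1 adds 16 to the high-nibble values; qhbits is then shifted
            // right by 2 so the next iteration sees the following pair of bit planes.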
            q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
            q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
            q5h.val[2] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[0]), 3);
            q5h.val[3] = vshlq_n_u8(vandq_u8(mtwo, qhbits.val[1]), 3);
            qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 2);
            qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 2);

            q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
            q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
            q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
            q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)

            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * *scales++;
            sumi += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * *scales++;

#else

            const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                           vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
            const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                           vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
            sumi += vaddvq_s16(vaddq_s16(p0, p1)) * *scales++;

            const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                                           vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
            const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                           vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
            sumi += vaddvq_s16(vaddq_s16(p2, p3)) * *scales++;

#endif
        }

        sumf += d * sumi - dmin * sumi_mins;

    }

    *s = sumf;

#elif defined __AVX2__

    const __m256i m4 = _mm256_set1_epi8(0xF);
    const __m128i mzero = _mm_setzero_si128();
    const __m256i mone  = _mm256_set1_epi8(1);

    __m256 acc = _mm256_setzero_ps();

    float summs = 0.f;

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
        const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin);

        const uint8_t * restrict q5 = x[i].qs;
        const int8_t  * restrict q8 = y[i].qs;

        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        const __m256i mins_and_scales = _mm256_cvtepu8_epi16(_mm_set_epi32(utmp[3], utmp[2], utmp[1], utmp[0]));

        const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums);
        const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
        const __m128i prod = _mm_madd_epi16(_mm256_extracti128_si256(mins_and_scales, 1), q8s);
        const __m128i hsum = _mm_hadd_epi32(_mm_hadd_epi32(prod, mzero), mzero);
        summs += dmin * _mm_extract_epi32(hsum, 0);

        const __m128i sc128  = _mm256_extracti128_si256(mins_and_scales, 0);
        const __m256i scales = _mm256_set_m128i(sc128, sc128);

        const __m256i hbits = _mm256_loadu_si256((const __m256i*)x[i].qh);
        __m256i hmask = mone;

        __m256i sumi = _mm256_setzero_si256();

        int bit = 0;
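        // hmask/bit walk one high-bit plane per 32-value sub-block: each extracted bit is shifted
        // up to contribute +16 on top of the 4-bit base value (q5_K values are 0..31, with the
        // sub-block min handled separately via dmin/summs).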
        for (int j = 0; j < QK_K/64; ++j) {

            const __m256i scale_0 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+0));
            const __m256i scale_1 = _mm256_shuffle_epi8(scales, get_scale_shuffle_k4(2*j+1));

            const __m256i q5bits = _mm256_loadu_si256((const __m256i*)q5); q5 += 32;

            const __m256i q5l_0 = _mm256_and_si256(q5bits, m4);
            const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
            const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
            hmask = _mm256_slli_epi16(hmask, 1);

            const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), m4);
            const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), bit++), 4);
            const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
            hmask = _mm256_slli_epi16(hmask, 1);

            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            __m256i p16_0 = _mm256_maddubs_epi16(q5_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q5_1, q8_1);

            p16_0 = _mm256_madd_epi16(scale_0, p16_0);
            p16_1 = _mm256_madd_epi16(scale_1, p16_1);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));

        }

        __m256 vd = _mm256_set1_ps(d);
        acc = _mm256_fmadd_ps(vd, _mm256_cvtepi32_ps(sumi), acc);

    }

    *s = hsum_float_8(acc) + summs;

#else

    const uint8_t * scales = (const uint8_t*)&utmp[0];
    const uint8_t * mins   = (const uint8_t*)&utmp[2];

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q4 = x[i].qs;
        const uint8_t * restrict hm = x[i].qh;
        const  int8_t * restrict q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * restrict a = aux8;
        uint8_t m = 1;
        for (int j = 0; j < QK_K/64; ++j) {
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] & 0xF);
            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
            a += 32; m <<= 1;
            for (int l = 0; l < 32; ++l) a[l] = (int8_t)(q4[l] >> 4);
            for (int l = 0; l < 32; ++l) a[l] += (hm[l] & m ? 16 : 0);
            a += 32; m <<= 1;
            q4 += 32;
        }
        memcpy(utmp, x[i].scales, 12);
        utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
        const uint32_t uaux = utmp[1] & kmask1;
        utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
        utmp[2] = uaux;
        utmp[0] &= kmask1;

        int sumi = 0;
        for (int j = 0; j < QK_K/16; ++j) sumi += y[i].bsums[j] * mins[j/2];

        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/32; ++j) {
            int32_t scale = scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
        const float dmin = ggml_fp16_to_fp32(x[i].dmin) * y[i].d;
        sumf -= dmin * sumi;
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}

void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    assert(n % QK_K == 0);

    const block_q6_K * restrict x = vx;
    const block_q8_K * restrict y = vy;

    const int nb = n / QK_K;

#ifdef __ARM_NEON

    float sum = 0;

    const uint8x16_t m4b = vdupq_n_u8(0xF);
    const int32x4_t  vzero = vdupq_n_s32(0);
    //const int8x16_t  m32s = vdupq_n_s8(32);

    const uint8x16_t mone = vdupq_n_u8(3);

    int8x16x4_t q6bytes;
    uint8x16x4_t q6h;

    for (int i = 0; i < nb; ++i) {

        const float d_all = ggml_fp16_to_fp32(x[i].d);

        const uint8_t * restrict q6 = x[i].ql;
        const uint8_t * restrict qh = x[i].qh;
        const int8_t  * restrict q8 = y[i].qs;

        const int8_t * restrict scale = x[i].scales;

        const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
        const int8x16_t scales = vld1q_s8(scale);
        const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};

        const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
                                                   vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
                                         vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[1]), vget_low_s16 (q6scales.val[1])),
                                                   vmull_s16(vget_high_s16(q8sums.val[1]), vget_high_s16(q6scales.val[1]))));
        int32_t isum_mins = vaddvq_s32(prod);

        int32_t isum = 0;

        for (int j = 0; j < QK_K/128; ++j) {

            uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
            uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
            int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;

            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
            uint8x16_t shifted = vshrq_n_u8(qhbits.val[0], 2);
            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
            shifted = vshrq_n_u8(qhbits.val[1], 2);
            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);

            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0])), m32s);
            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1])), m32s);
            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2])), m32s);
            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3])), m32s);
            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[0], m4b), q6h.val[0]));
            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[1], m4b), q6h.val[1]));
            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[2], m4b), q6h.val[2]));
            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q6bits.val[3], m4b), q6h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)

            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
            scale += 4;

#else

            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                                     vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                                     vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
            scale += 2;

            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                                     vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                                     vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
            scale += 2;

#endif

            q8bytes = vld1q_s8_x4(q8); q8 += 64;

            shifted = vshrq_n_u8(qhbits.val[0], 4);
            q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
            shifted = vshrq_n_u8(qhbits.val[1], 4);
            q6h.val[1] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
            shifted = vshrq_n_u8(qhbits.val[0], 6);
            q6h.val[2] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
            shifted = vshrq_n_u8(qhbits.val[1], 6);
            q6h.val[3] = vshlq_n_u8(vandq_u8(mone, shifted), 4);

            //q6bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0])), m32s);
            //q6bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1])), m32s);
            //q6bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2])), m32s);
            //q6bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3])), m32s);
            q6bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[0], 4), q6h.val[0]));
            q6bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[1], 4), q6h.val[1]));
            q6bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[2], 4), q6h.val[2]));
            q6bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6bits.val[3], 4), q6h.val[3]));

#if defined(__ARM_FEATURE_DOTPROD)

            isum += vaddvq_s32(vdotq_s32(vzero, q6bytes.val[0], q8bytes.val[0])) * scale[0] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[1], q8bytes.val[1])) * scale[1] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[2], q8bytes.val[2])) * scale[2] +
                    vaddvq_s32(vdotq_s32(vzero, q6bytes.val[3], q8bytes.val[3])) * scale[3];
            scale += 4;

            //for (int l = 0; l < 4; ++l) {
            //    const int32x4_t p = vdotq_s32(vzero, q6bytes.val[l], q8bytes.val[l]);
            //    isum += vaddvq_s32(p) * *scale++;
            //}

#else

            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
                           vmull_s8(vget_high_s8(q6bytes.val[0]), vget_high_s8(q8bytes.val[0])));
            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
                           vmull_s8(vget_high_s8(q6bytes.val[1]), vget_high_s8(q8bytes.val[1])));
            isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1];
            scale += 2;

            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
                           vmull_s8(vget_high_s8(q6bytes.val[2]), vget_high_s8(q8bytes.val[2])));
            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q6bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
                           vmull_s8(vget_high_s8(q6bytes.val[3]), vget_high_s8(q8bytes.val[3])));
            isum += vaddvq_s16(p2) * scale[0] + vaddvq_s16(p3) * scale[1];
            scale += 2;

#endif

        }

        //sum += isum * d_all * y[i].d;
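        // q6_K values are (6-bit quant) - 32, but the -32 was left out of q6bytes above (see the
        // commented-out vsubq_s8 lines); it is applied here instead as -32 * sum_j(scale_j * bsum_j),
        // using the per-sub-block q8 sums accumulated into isum_mins.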
        sum += d_all * y[i].d * (isum - 32 * isum_mins);

    }
    *s = sum;

#elif defined __AVX2__

    const __m256i m4 = _mm256_set1_epi8(0xF);
    const __m256i m2 = _mm256_set1_epi8(3);
    const __m256i m32s = _mm256_set1_epi8(32);
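    // The same -32 offset is handled here by computing maddubs(32, q8) per 32-value chunk and
    // subtracting it from maddubs(q6, q8), so the unsigned operand of _mm256_maddubs_epi16 never
    // has to hold the negative (q6 - 32) values.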
    __m256 acc = _mm256_setzero_ps();

    for (int i = 0; i < nb; ++i) {

        const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);

        const uint8_t * restrict q4 = x[i].ql;
        const uint8_t * restrict qh = x[i].qh;
        const int8_t  * restrict q8 = y[i].qs;

        const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales);

        __m256i sumi = _mm256_setzero_si256();

        int is = 0;

        for (int j = 0; j < QK_K/128; ++j) {

            const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0));
            const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1));
            const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2));
            const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3));
            is += 4;

            const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32;
            const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32;

            const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4);
            const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4);
            const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4);
            const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4);

            const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0);
            const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1);
            const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2);
            const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3);

            const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
            const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;

            __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0);
            __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1);
            __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2);
            __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3);

            __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0);
            __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1);
            __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2);
            __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3);

            p16_0 = _mm256_sub_epi16(p16_0, q8s_0);
            p16_1 = _mm256_sub_epi16(p16_1, q8s_1);
            p16_2 = _mm256_sub_epi16(p16_2, q8s_2);
            p16_3 = _mm256_sub_epi16(p16_3, q8s_3);

            p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0);
            p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1);
            p16_2 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_2), p16_2);
            p16_3 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_3), p16_3);

            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_0, p16_1));
            sumi = _mm256_add_epi32(sumi, _mm256_add_epi32(p16_2, p16_3));

        }

        acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc);
    }

    *s = hsum_float_8(acc);

#else

    int8_t  aux8[QK_K];
    int16_t aux16[8];
    float   sums [8];
    int32_t aux32[8];
    memset(sums, 0, 8*sizeof(float));

    float sumf = 0;
    for (int i = 0; i < nb; ++i) {
        const uint8_t * restrict q4 = x[i].ql;
        const uint8_t * restrict qh = x[i].qh;
        const  int8_t * restrict q8 = y[i].qs;
        memset(aux32, 0, 8*sizeof(int32_t));
        int8_t * restrict a = aux8;
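        // reconstruct the signed 6-bit values: each of the 128 values in a chunk combines a 4-bit
        // nibble from ql with a 2-bit field from qh, then subtracts the 32 offset.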
        for (int j = 0; j < QK_K; j += 128) {
            for (int l = 0; l < 32; ++l) {
                a[l +  0] = (int8_t)((q4[l +  0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                a[l + 32] = (int8_t)((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                a[l + 64] = (int8_t)((q4[l +  0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                a[l + 96] = (int8_t)((q4[l + 32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32;
            }
            a  += 128;
            q4 += 64;
            qh += 32;
        }
        a = aux8;
        int is = 0;
        for (int j = 0; j < QK_K/16; ++j) {
            int scale = x[i].scales[is++];
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
            for (int l = 0; l < 8; ++l) aux16[l] = q8[l] * a[l];
            for (int l = 0; l < 8; ++l) aux32[l] += scale * aux16[l];
            q8 += 8; a += 8;
        }
        const float d = ggml_fp16_to_fp32(x[i].d) * y[i].d;
        for (int l = 0; l < 8; ++l) sums[l] += d * aux32[l];
    }
    for (int l = 0; l < 8; ++l) sumf += sums[l];
    *s = sumf;

#endif

}