repack.cpp

  1. #define GGML_COMMON_IMPL_CPP
  2. #define GGML_COMMON_DECL_CPP
  3. #include "ggml-common.h"
  4. #include "ggml-backend-impl.h"
  5. #include "ggml-impl.h"
  6. #include "ggml-cpu.h"
  7. #include "ggml-cpu-impl.h"
  8. #include "simd-mappings.h"
  9. #include "traits.h"
  10. #include "arch-fallback.h"
  11. #include <cmath>
  12. #include <cstring>
  13. #include <cassert>
  14. #include <cstdio> // for GGML_ASSERT
  15. #include "repack.h"
  16. #if defined(__GNUC__)
  17. #pragma GCC diagnostic ignored "-Woverlength-strings"
  18. #endif
  19. #define UNUSED GGML_UNUSED
  20. static inline int nearest_int(float fval) {
  21. assert(fabsf(fval) <= 4194303.f);
  22. float val = fval + 12582912.f;
  23. int i; memcpy(&i, &val, sizeof(int));
  24. return (i & 0x007fffff) - 0x00400000;
  25. }
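// Note on nearest_int(): this is the classic "add 1.5 * 2^23" rounding trick. For
// |fval| < 2^22, fval + 12582912.0f lands in [2^23, 2^24), so the float's mantissa
// ends up holding 0x400000 + round(fval) (ties to even); masking the low 23 bits and
// subtracting 0x400000 recovers the rounded integer without a float-to-int conversion.
// e.g. fval = 3.7f -> val rounds to 12582916.0f, whose mantissa bits are 0x400004,
// giving 0x400004 - 0x400000 = 4.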
  26. // Functions to create the interleaved data layout formats
  27. // interleave 4 block_q4_0s in blocks of blck_size_interleave
  28. // returns an interleaved block_q4_0x4
  29. // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
  30. // first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
  31. //
  32. // - in : an array of block_q4_0 pointers
  33. // - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
  34. // blck_size_interleave bytes
  35. // - xor_mask : the mask to convert the nibbles in block_q4_0 quants bytes
  36. // from bias offset form to pure sign form (this saves subtract
  37. // operations during unpacking)
  38. //
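// For example, with blck_size_interleave = 4 the interleaved quants are laid out as
//   in[0].qs[0..3], in[1].qs[0..3], in[2].qs[0..3], in[3].qs[0..3], in[0].qs[4..7], ...
// and XOR-ing each byte with 0x88 flips the top bit of both nibbles, turning the stored
// (value + 8) offset-binary nibbles into plain 4-bit two's-complement values.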
  39. extern "C" {
  40. void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  41. assert(QK8_0 == 32);
  42. assert(k % QK8_0 == 0);
  43. const int nb = k / QK8_0;
  44. block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
  45. // scalar
  46. const int blck_size_interleave = 4;
  47. float srcv[4][QK8_0];
  48. float id[4];
  49. for (int i = 0; i < nb; i++) {
  50. for (int row_iter = 0; row_iter < 4; row_iter++) {
  51. float amax = 0.0f; // absolute max
  52. for (int j = 0; j < QK8_0; j++) {
  53. srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
  54. amax = MAX(amax, fabsf(srcv[row_iter][j]));
  55. }
  56. const float d = amax / ((1 << 7) - 1);
  57. id[row_iter] = d ? 1.0f / d : 0.0f;
  58. y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
  59. }
  60. for (int j = 0; j < QK8_0 * 4; j++) {
  61. int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
  62. int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
  63. src_offset += (j % blck_size_interleave);
  64. float x0 = srcv[src_id][src_offset] * id[src_id];
  65. y[i].qs[j] = roundf(x0);
  66. }
  67. }
  68. }
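// In the 4x4 layout above, the j loop walks the 4 source rows round-robin, 4 quantized
// values at a time: j = 0..3 come from srcv[0][0..3], j = 4..7 from srcv[1][0..3], ...,
// j = 12..15 from srcv[3][0..3], then j = 16..19 from srcv[0][4..7], and so on.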
  69. void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  70. assert(QK8_0 == 32);
  71. assert(k % QK8_0 == 0);
  72. const int nb = k / QK8_0;
  73. block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy;
  74. // scalar
  75. const int blck_size_interleave = 8;
  76. float srcv[4][QK8_0];
  77. float id[4];
  78. for (int i = 0; i < nb; i++) {
  79. for (int row_iter = 0; row_iter < 4; row_iter++) {
  80. float amax = 0.0f; // absolute max
  81. for (int j = 0; j < QK8_0; j++) {
  82. srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j];
  83. amax = MAX(amax, fabsf(srcv[row_iter][j]));
  84. }
  85. const float d = amax / ((1 << 7) - 1);
  86. id[row_iter] = d ? 1.0f / d : 0.0f;
  87. y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d);
  88. }
  89. for (int j = 0; j < QK8_0 * 4; j++) {
  90. int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
  91. int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
  92. src_offset += (j % blck_size_interleave);
  93. float x0 = srcv[src_id][src_offset] * id[src_id];
  94. y[i].qs[j] = roundf(x0);
  95. }
  96. }
  97. }
  98. void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  99. assert(QK_K == 256);
  100. assert(k % QK_K == 0);
  101. const int nb = k / QK_K;
  102. block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
  103. // scalar
  104. const int blck_size_interleave = 4;
  105. float srcv[4][QK_K];
  106. float iscale[4];
  107. for (int i = 0; i < nb; i++) {
  108. for (int row_iter = 0; row_iter < 4; row_iter++) {
  109. float amax = 0.0f; // absolute max
  110. float max = 0;
  111. for (int j = 0; j < QK_K; j++) {
  112. srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
  113. // Update the maximum value of the corresponding super block
  114. if(amax < fabsf(srcv[row_iter][j])) {
  115. amax = fabsf(srcv[row_iter][j]);
  116. max = srcv[row_iter][j];
  117. }
  118. }
  119. iscale[row_iter] = amax ? -127.f/max : 0;
  120. y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
  121. }
  122. for (int j = 0; j < QK_K / 4; j++) {
  123. y[i].bsums[j] = 0;
  124. }
  125. // Quant values are interleaved in sequences of four bytes taken from the corresponding super blocks
  126. // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving,
  127. // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
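// Concretely, with group = src_offset / 16 (the 16-value group a quant belongs to within
// its source row), the destination bsum slot computed below is
//   index = src_id * 4 + (group % 4) + (group / 4) * 16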
  128. for (int j = 0; j < QK_K * 4; j++) {
  129. int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
  130. int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
  131. src_offset += (j % blck_size_interleave);
  132. int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
  133. float x0 = srcv[src_id][src_offset] * iscale[src_id];
  134. y[i].qs[j] = nearest_int(x0);
  135. y[i].bsums[index] += y[i].qs[j];
  136. }
  137. }
  138. }
  139. void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
  140. assert(QK_K == 256);
  141. assert(k % QK_K == 0);
  142. const int nb = k / QK_K;
  143. block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
  144. // scalar
  145. const int blck_size_interleave = 8;
  146. float srcv[4][QK_K];
  147. float iscale[4];
  148. for (int i = 0; i < nb; i++) {
  149. for (int row_iter = 0; row_iter < 4; row_iter++) {
  150. float amax = 0.0f; // absolute max
  151. float max = 0;
  152. for (int j = 0; j < QK_K; j++) {
  153. srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
  154. // Update the maximum value of the corresponding super block
  155. if(amax < fabsf(srcv[row_iter][j])) {
  156. amax = fabsf(srcv[row_iter][j]);
  157. max = srcv[row_iter][j];
  158. }
  159. }
  160. iscale[row_iter] = amax ? -127.f/max : 0;
  161. y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
  162. }
  163. for (int j = 0; j < QK_K / 4; j++) {
  164. y[i].bsums[j] = 0;
  165. }
  166. // Quant values are interleaved in sequences of eight bytes taken from the corresponding super blocks
  167. // Bsums values are interleaved in sequences of four bsums from each super block taken for interleaving,
  168. // i.e. the first four bsums from the first super block, followed by the first four bsums from the second super block, and so on
  169. for (int j = 0; j < QK_K * 4; j++) {
  170. int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
  171. int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
  172. src_offset += (j % blck_size_interleave);
  173. int index = (((j & 31) >> 3) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
  174. float x0 = srcv[src_id][src_offset] * iscale[src_id];
  175. y[i].qs[j] = nearest_int(x0);
  176. y[i].bsums[index] += y[i].qs[j];
  177. }
  178. }
  179. }
  180. } // extern "C"
  181. template <int64_t INTER_SIZE, ggml_type PARAM_TYPE>
  182. void ggml_quantize_mat_t(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row);
  183. template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
  184. assert(nrow == 4);
  185. UNUSED(nrow);
  186. ggml_quantize_mat_q8_0_4x4(x, vy, n_per_row);
  187. }
  188. template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
  189. assert(nrow == 4);
  190. UNUSED(nrow);
  191. ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
  192. }
  193. template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
  194. assert(nrow == 4);
  195. UNUSED(nrow);
  196. ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
  197. }
  198. template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
  199. assert(nrow == 4);
  200. UNUSED(nrow);
  201. ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
  202. }
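// Illustrative use of the dispatch template above (buffer names are hypothetical):
// quantize a tile of 4 activation rows of length n_per_row into the interleaved
// Q8_0 format consumed by the 4x8 kernels:
//   ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(src_rows, dst_blocks, /*nrow=*/4, n_per_row);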
  203. extern "C" {
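// The q4_0 gemv/gemm kernels below rely on the repacked nibbles already being in
// two's-complement form (see the xor_mask in make_block_q4_0x4/x8 further down), so that
//   (int8_t)(byte << 4)   == 16 * (signed low nibble)
//   (int8_t)(byte & 0xF0) == 16 * (signed high nibble)
// The low/high nibbles hold the first/second half of each 32-value block (hence the
// + qk / 2 offset into the activations), and the final >> 4 divides the accumulated
// multiple-of-16 products back down, avoiding any per-nibble bias subtraction.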
  204. void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  205. const int qk = QK8_0;
  206. const int nb = n / qk;
  207. const int ncols_interleaved = 4;
  208. const int blocklen = 4;
  209. assert(nr == 1);
  210. assert(n % qk == 0);
  211. assert(nc % ncols_interleaved == 0);
  212. UNUSED(s);
  213. UNUSED(bs);
  214. UNUSED(vx);
  215. UNUSED(vy);
  216. UNUSED(nr);
  217. UNUSED(nc);
  218. UNUSED(nb);
  219. UNUSED(ncols_interleaved);
  220. UNUSED(blocklen);
  221. float sumf[4];
  222. int sumi;
  223. const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
  224. for (int x = 0; x < nc / ncols_interleaved; x++) {
  225. const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
  226. for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
  227. for (int l = 0; l < nb; l++) {
  228. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  229. for (int j = 0; j < ncols_interleaved; j++) {
  230. sumi = 0;
  231. for (int i = 0; i < blocklen; ++i) {
  232. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  233. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  234. sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
  235. }
  236. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
  237. }
  238. }
  239. }
  240. for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
  241. }
  242. }
  243. void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  244. const int qk = QK8_0;
  245. const int nb = n / qk;
  246. const int ncols_interleaved = 4;
  247. const int blocklen = 8;
  248. assert (n % qk == 0);
  249. assert (nc % ncols_interleaved == 0);
  250. UNUSED(s);
  251. UNUSED(bs);
  252. UNUSED(vx);
  253. UNUSED(vy);
  254. UNUSED(nr);
  255. UNUSED(nc);
  256. UNUSED(nb);
  257. UNUSED(ncols_interleaved);
  258. UNUSED(blocklen);
  259. float sumf[4];
  260. int sumi;
  261. const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
  262. for (int x = 0; x < nc / ncols_interleaved; x++) {
  263. const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
  264. for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
  265. for (int l = 0; l < nb; l++) {
  266. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  267. for (int j = 0; j < ncols_interleaved; j++) {
  268. sumi = 0;
  269. for (int i = 0; i < blocklen; ++i) {
  270. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  271. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  272. sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
  273. }
  274. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
  275. }
  276. }
  277. }
  278. for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
  279. }
  280. }
  281. void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  282. const int qk = QK8_0;
  283. const int nb = n / qk;
  284. const int ncols_interleaved = 8;
  285. const int blocklen = 8;
  286. assert (n % qk == 0);
  287. assert (nc % ncols_interleaved == 0);
  288. UNUSED(s);
  289. UNUSED(bs);
  290. UNUSED(vx);
  291. UNUSED(vy);
  292. UNUSED(nr);
  293. UNUSED(nc);
  294. UNUSED(nb);
  295. UNUSED(ncols_interleaved);
  296. UNUSED(blocklen);
  297. float sumf[8];
  298. int sumi;
  299. const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
  300. for (int x = 0; x < nc / ncols_interleaved; x++) {
  301. const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
  302. for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
  303. for (int l = 0; l < nb; l++) {
  304. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  305. for (int j = 0; j < ncols_interleaved; j++) {
  306. sumi = 0;
  307. for (int i = 0; i < blocklen; ++i) {
  308. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  309. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  310. sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4;
  311. }
  312. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
  313. }
  314. }
  315. }
  316. for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
  317. }
  318. }
  319. void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  320. const int qk = QK_K;
  321. const int nb = n / qk;
  322. const int ncols_interleaved = 8;
  323. const int blocklen = 4;
  324. static const uint32_t kmask1 = 0x3f3f3f3f;
  325. static const uint32_t kmask2 = 0x0f0f0f0f;
  326. static const uint32_t kmask3 = 0x03030303;
  327. assert (n % qk == 0);
  328. assert (nc % ncols_interleaved == 0);
  329. UNUSED(bs);
  330. UNUSED(nr);
  331. float sumf[8];
  332. float sum_minf[8];
  333. uint32_t utmp[32];
  334. int sumi1;
  335. int sumi2;
  336. int sumi;
  337. const block_q8_K * a_ptr = (const block_q8_K *) vy;
  338. for (int x = 0; x < nc / ncols_interleaved; x++) {
  339. const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
  340. for (int j = 0; j < ncols_interleaved; j++) {
  341. sumf[j] = 0.0;
  342. sum_minf[j] = 0.0;
  343. }
  344. for (int l = 0; l < nb; l++) {
  345. for (int sb = 0; sb < 8; sb++) {
  346. memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
  347. utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
  348. const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
  349. utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
  350. utmp[sb * 4 + 2] = uaux_0;
  351. utmp[sb * 4 + 0] &= kmask1;
  352. }
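// After this unpacking, (uint8_t *) utmp + sb * 16 holds the eight per-column 6-bit scales
// of scale group sb in bytes 0..7 and the corresponding eight 6-bit mins in bytes 8..15,
// which is what scales_0/scales_1 and mins index into below.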
  353. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  354. uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
  355. uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
  356. for (int j = 0; j < ncols_interleaved; j++) {
  357. sumi1 = 0;
  358. sumi2 = 0;
  359. sumi = 0;
  360. for (int i = 0; i < blocklen; ++i) {
  361. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
  362. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
  363. sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
  364. sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
  365. sumi1 = sumi1 * scales_0[j];
  366. sumi2 = sumi2 * scales_1[j];
  367. sumi += sumi1 + sumi2;
  368. }
  369. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
  370. }
  371. }
  372. for (int sb = 0; sb < 8; sb++) {
  373. uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
  374. for (int j = 0; j < ncols_interleaved; j++) {
  375. sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
  376. }
  377. }
  378. }
  379. for (int j = 0; j < ncols_interleaved; j++) {
  380. s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
  381. }
  382. }
  383. }
  384. void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  385. const int qk = QK_K;
  386. const int nb = n / qk;
  387. const int ncols_interleaved = 8;
  388. const int blocklen = 8;
  389. static const uint32_t kmask1 = 0x3f3f3f3f;
  390. static const uint32_t kmask2 = 0x0f0f0f0f;
  391. static const uint32_t kmask3 = 0x03030303;
  392. assert (n % qk == 0);
  393. assert (nc % ncols_interleaved == 0);
  394. UNUSED(s);
  395. UNUSED(bs);
  396. UNUSED(vx);
  397. UNUSED(vy);
  398. UNUSED(nr);
  399. UNUSED(nc);
  400. UNUSED(nb);
  401. UNUSED(ncols_interleaved);
  402. UNUSED(blocklen);
  403. float sumf[8];
  404. float sum_minf[8];
  405. uint32_t utmp[32];
  406. int sumi1;
  407. int sumi2;
  408. int sumi;
  409. const block_q8_K * a_ptr = (const block_q8_K *) vy;
  410. for (int x = 0; x < nc / ncols_interleaved; x++) {
  411. const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
  412. for (int j = 0; j < ncols_interleaved; j++) {
  413. sumf[j] = 0.0;
  414. sum_minf[j] = 0.0;
  415. }
  416. for (int l = 0; l < nb; l++) {
  417. for (int sb = 0; sb < 8; sb++) {
  418. memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
  419. utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
  420. const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
  421. utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
  422. utmp[sb * 4 + 2] = uaux_0;
  423. utmp[sb * 4 + 0] &= kmask1;
  424. }
  425. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  426. uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
  427. uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
  428. for (int j = 0; j < ncols_interleaved; j++) {
  429. sumi1 = 0;
  430. sumi2 = 0;
  431. sumi = 0;
  432. for (int i = 0; i < blocklen; ++i) {
  433. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
  434. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
  435. sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]);
  436. sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]);
  437. sumi1 = sumi1 * scales_0[j];
  438. sumi2 = sumi2 * scales_1[j];
  439. sumi += sumi1 + sumi2;
  440. }
  441. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
  442. }
  443. }
  444. for (int sb = 0; sb < 8; sb++) {
  445. uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
  446. for (int j = 0; j < ncols_interleaved; j++) {
  447. sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
  448. }
  449. }
  450. }
  451. for (int j = 0; j < ncols_interleaved; j++) {
  452. s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
  453. }
  454. }
  455. }
  456. void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  457. const int qk = QK_K;
  458. const int nb = n / qk;
  459. const int ncols_interleaved = 8;
  460. const int blocklen = 8;
  461. assert (n % qk == 0);
  462. assert (nc % ncols_interleaved == 0);
  463. UNUSED(s);
  464. UNUSED(bs);
  465. UNUSED(vx);
  466. UNUSED(vy);
  467. UNUSED(nr);
  468. UNUSED(nc);
  469. UNUSED(nb);
  470. UNUSED(ncols_interleaved);
  471. UNUSED(blocklen);
  472. float sumf[8];
  473. float sum_minf[8];
  474. int sumi1,sumi2,sumi3,sumi4;
  475. int sumi;
  476. const block_q8_K * a_ptr = (const block_q8_K *)vy;
  477. for(int x = 0; x < nc / ncols_interleaved; x++) {
  478. const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
  479. for (int j = 0; j < ncols_interleaved; j++) {
  480. sumf[j] = 0.0;
  481. sum_minf[j] = 0.0;
  482. }
  483. for (int l = 0; l < nb; l++) {
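// b_ptr[l].scales is laid out as 8 groups of 16 bytes; within a group, bytes 2*j and
// 2*j + 1 belong to column j, each packing a 4-bit sub-block scale in the low nibble and
// a 4-bit min in the high nibble (hence the & 0xF below and the >> 4 in the mins loop).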
  484. for (int k = 0; k < (qk / (4 * blocklen)); k++) {
  485. const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
  486. const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
  487. const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
  488. const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
  489. for (int j = 0; j < ncols_interleaved; j++) {
  490. sumi1 = 0;
  491. sumi2 = 0;
  492. sumi3 = 0;
  493. sumi4 = 0;
  494. sumi = 0;
  495. int offset = ((k / 2) % 2) + j * 2;
  496. for (int i = 0; i < blocklen; ++i){
  497. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
  498. const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
  499. const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
  500. const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
  501. sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]);
  502. sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]);
  503. sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]);
  504. sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]);
  505. sumi1 = sumi1 * (scales_0[offset] & 0xF);
  506. sumi2 = sumi2 * (scales_1[offset] & 0xF);
  507. sumi3 = sumi3 * (scales_2[offset] & 0xF);
  508. sumi4 = sumi4 * (scales_3[offset] & 0xF);
  509. sumi += sumi1 + sumi2 + sumi3 + sumi4;
  510. }
  511. sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
  512. }
  513. }
  514. for(int sb = 0; sb < 8; sb++) {
  515. const uint8_t *mins = b_ptr[l].scales + sb * 16;
  516. for(int j = 0; j < ncols_interleaved; j++){
  517. sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
  518. }
  519. }
  520. }
  521. for (int j = 0; j < ncols_interleaved; j++) {
  522. s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
  523. }
  524. }
  525. }
  526. void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  527. const int qk = QK8_0;
  528. const int nb = n / qk;
  529. const int ncols_interleaved = 4;
  530. const int blocklen = 4;
  531. assert(nr == 1);
  532. assert(n % qk == 0);
  533. assert(nc % ncols_interleaved == 0);
  534. UNUSED(bs);
  535. UNUSED(nr);
  536. float sumf[4];
  537. int sumi;
  538. const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
  539. for (int x = 0; x < nc / ncols_interleaved; x++) {
  540. const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
  541. for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
  542. for (int l = 0; l < nb; l++) {
  543. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  544. for (int j = 0; j < ncols_interleaved; j++) {
  545. sumi = 0;
  546. for (int i = 0; i < blocklen; ++i) {
  547. const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
  548. const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
  549. sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
  550. }
  551. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
  552. }
  553. }
  554. }
  555. for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
  556. }
  557. }
  558. void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  559. const int qk = QK8_0;
  560. const int nb = n / qk;
  561. const int ncols_interleaved = 8;
  562. const int blocklen = 8;
  563. assert(nr == 1);
  564. assert(n % qk == 0);
  565. assert(nc % ncols_interleaved == 0);
  566. UNUSED(bs);
  567. UNUSED(nr);
  568. float sumf[8];
  569. int sumi;
  570. const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
  571. for (int x = 0; x < nc / ncols_interleaved; x++) {
  572. const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
  573. for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
  574. for (int l = 0; l < nb; l++) {
  575. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  576. for (int j = 0; j < ncols_interleaved; j++) {
  577. sumi = 0;
  578. for (int i = 0; i < blocklen; ++i) {
  579. const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
  580. const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
  581. sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
  582. }
  583. sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
  584. }
  585. }
  586. }
  587. for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
  588. }
  589. }
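// The gemm kernels below compute a 4-row by ncols_interleaved tile per step; vy must point
// to activations pre-quantized with the matching ggml_quantize_mat_* 4xN routine, so the
// four rows are already interleaved as block_q8_0x4 / block_q8_Kx4.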
  590. void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  591. const int qk = QK8_0;
  592. const int nb = n / qk;
  593. const int ncols_interleaved = 4;
  594. const int blocklen = 4;
  595. assert (n % qk == 0);
  596. assert (nr % 4 == 0);
  597. assert (nc % ncols_interleaved == 0);
  598. UNUSED(s);
  599. UNUSED(bs);
  600. UNUSED(vx);
  601. UNUSED(vy);
  602. UNUSED(nr);
  603. UNUSED(nc);
  604. UNUSED(nb);
  605. UNUSED(ncols_interleaved);
  606. UNUSED(blocklen);
  607. {
  608. float sumf[4][4];
  609. int sumi;
  610. for (int y = 0; y < nr / 4; y++) {
  611. const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  612. for (int x = 0; x < nc / ncols_interleaved; x++) {
  613. const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
  614. for (int m = 0; m < 4; m++) {
  615. for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
  616. }
  617. for (int l = 0; l < nb; l++) {
  618. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  619. for (int m = 0; m < 4; m++) {
  620. for (int j = 0; j < ncols_interleaved; j++) {
  621. sumi = 0;
  622. for (int i = 0; i < blocklen; ++i) {
  623. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  624. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  625. sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
  626. (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
  627. }
  628. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
  629. }
  630. }
  631. }
  632. }
  633. for (int m = 0; m < 4; m++) {
  634. for (int j = 0; j < ncols_interleaved; j++)
  635. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
  636. }
  637. }
  638. }
  639. }
  640. }
  641. void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  642. const int qk = QK8_0;
  643. const int nb = n / qk;
  644. const int ncols_interleaved = 4;
  645. const int blocklen = 8;
  646. assert (n % qk == 0);
  647. assert (nr % 4 == 0);
  648. assert (nc % ncols_interleaved == 0);
  649. UNUSED(s);
  650. UNUSED(bs);
  651. UNUSED(vx);
  652. UNUSED(vy);
  653. UNUSED(nr);
  654. UNUSED(nc);
  655. UNUSED(nb);
  656. UNUSED(ncols_interleaved);
  657. UNUSED(blocklen);
  658. float sumf[4][4];
  659. int sumi;
  660. for (int y = 0; y < nr / 4; y++) {
  661. const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  662. for (int x = 0; x < nc / ncols_interleaved; x++) {
  663. const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb);
  664. for (int m = 0; m < 4; m++) {
  665. for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
  666. }
  667. for (int l = 0; l < nb; l++) {
  668. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  669. for (int m = 0; m < 4; m++) {
  670. for (int j = 0; j < ncols_interleaved; j++) {
  671. sumi = 0;
  672. for (int i = 0; i < blocklen; ++i) {
  673. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  674. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  675. sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
  676. (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
  677. }
  678. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
  679. }
  680. }
  681. }
  682. }
  683. for (int m = 0; m < 4; m++) {
  684. for (int j = 0; j < ncols_interleaved; j++)
  685. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
  686. }
  687. }
  688. }
  689. }
  690. void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  691. const int qk = QK8_0;
  692. const int nb = n / qk;
  693. const int ncols_interleaved = 8;
  694. const int blocklen = 8;
  695. assert (n % qk == 0);
  696. assert (nr % 4 == 0);
  697. assert (nc % ncols_interleaved == 0);
  698. UNUSED(s);
  699. UNUSED(bs);
  700. UNUSED(vx);
  701. UNUSED(vy);
  702. UNUSED(nr);
  703. UNUSED(nc);
  704. UNUSED(nb);
  705. UNUSED(ncols_interleaved);
  706. UNUSED(blocklen);
  707. float sumf[4][8];
  708. int sumi;
  709. for (int y = 0; y < nr / 4; y++) {
  710. const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  711. for (int x = 0; x < nc / ncols_interleaved; x++) {
  712. const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb);
  713. for (int m = 0; m < 4; m++) {
  714. for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
  715. }
  716. for (int l = 0; l < nb; l++) {
  717. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  718. for (int m = 0; m < 4; m++) {
  719. for (int j = 0; j < ncols_interleaved; j++) {
  720. sumi = 0;
  721. for (int i = 0; i < blocklen; ++i) {
  722. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4);
  723. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0);
  724. sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
  725. (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4;
  726. }
  727. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
  728. }
  729. }
  730. }
  731. }
  732. for (int m = 0; m < 4; m++) {
  733. for (int j = 0; j < ncols_interleaved; j++)
  734. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
  735. }
  736. }
  737. }
  738. }
  739. void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  740. const int qk = QK_K;
  741. const int nb = n / qk;
  742. const int ncols_interleaved = 8;
  743. const int blocklen = 4;
  744. static const uint32_t kmask1 = 0x3f3f3f3f;
  745. static const uint32_t kmask2 = 0x0f0f0f0f;
  746. static const uint32_t kmask3 = 0x03030303;
  747. assert (n % qk == 0);
  748. assert (nr % 4 == 0);
  749. assert (nc % ncols_interleaved == 0);
  750. UNUSED(nb);
  751. UNUSED(ncols_interleaved);
  752. UNUSED(blocklen);
  753. float sumf[4][8];
  754. float sum_minf[4][8];
  755. uint32_t utmp[32];
  756. int sumi1;
  757. int sumi2;
  758. int sumi;
  759. for (int y = 0; y < nr / 4; y++) {
  760. const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
  761. for (int x = 0; x < nc / ncols_interleaved; x++) {
  762. const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
  763. for (int m = 0; m < 4; m++) {
  764. for (int j = 0; j < ncols_interleaved; j++) {
  765. sumf[m][j] = 0.0;
  766. sum_minf[m][j] = 0.0;
  767. }
  768. }
  769. for (int l = 0; l < nb; l++) {
  770. for (int sb = 0; sb < 8; sb++) {
  771. memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
  772. utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
  773. const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
  774. utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
  775. utmp[sb * 4 + 2] = uaux_0;
  776. utmp[sb * 4 + 0] &= kmask1;
  777. }
  778. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  779. uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
  780. uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
  781. for (int m = 0; m < 4; m++) {
  782. for (int j = 0; j < ncols_interleaved; j++) {
  783. sumi1 = 0;
  784. sumi2 = 0;
  785. sumi = 0;
  786. for (int i = 0; i < blocklen; ++i) {
  787. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
  788. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
  789. sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
  790. sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
  791. sumi1 = sumi1 * scales_0[j];
  792. sumi2 = sumi2 * scales_1[j];
  793. sumi += sumi1 + sumi2;
  794. }
  795. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
  796. }
  797. }
  798. }
  799. for (int sb = 0; sb < 8; sb++) {
  800. uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
  801. for(int m = 0; m < 4; m++) {
  802. const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
  803. for(int j = 0; j < ncols_interleaved; j++) {
  804. sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
  805. }
  806. }
  807. }
  808. }
  809. for (int m = 0; m < 4; m++) {
  810. for (int j = 0; j < ncols_interleaved; j++) {
  811. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
  812. }
  813. }
  814. }
  815. }
  816. }
  817. void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  818. const int qk = QK_K;
  819. const int nb = n / qk;
  820. const int ncols_interleaved = 8;
  821. const int blocklen = 8;
  822. static const uint32_t kmask1 = 0x3f3f3f3f;
  823. static const uint32_t kmask2 = 0x0f0f0f0f;
  824. static const uint32_t kmask3 = 0x03030303;
  825. assert (n % qk == 0);
  826. assert (nr % 4 == 0);
  827. assert (nc % ncols_interleaved == 0);
  828. UNUSED(s);
  829. UNUSED(bs);
  830. UNUSED(vx);
  831. UNUSED(vy);
  832. UNUSED(nr);
  833. UNUSED(nc);
  834. UNUSED(nb);
  835. UNUSED(ncols_interleaved);
  836. UNUSED(blocklen);
  837. float sumf[4][8];
  838. float sum_minf[4][8];
  839. uint32_t utmp[32];
  840. int sumi1;
  841. int sumi2;
  842. int sumi;
  843. for (int y = 0; y < nr / 4; y++) {
  844. const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
  845. for (int x = 0; x < nc / ncols_interleaved; x++) {
  846. const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
  847. for (int m = 0; m < 4; m++) {
  848. for (int j = 0; j < ncols_interleaved; j++) {
  849. sumf[m][j] = 0.0;
  850. sum_minf[m][j] = 0.0;
  851. }
  852. }
  853. for (int l = 0; l < nb; l++) {
  854. for (int sb = 0; sb < 8; sb++) {
  855. memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
  856. utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
  857. const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
  858. utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
  859. utmp[sb * 4 + 2] = uaux_0;
  860. utmp[sb * 4 + 0] &= kmask1;
  861. }
  862. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  863. uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32;
  864. uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16;
  865. for (int m = 0; m < 4; m++) {
  866. for (int j = 0; j < ncols_interleaved; j++) {
  867. sumi1 = 0;
  868. sumi2 = 0;
  869. sumi = 0;
  870. for (int i = 0; i < blocklen; ++i) {
  871. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
  872. const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
  873. sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]);
  874. sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
  875. sumi1 = sumi1 * scales_0[j];
  876. sumi2 = sumi2 * scales_1[j];
  877. sumi += sumi1 + sumi2;
  878. }
  879. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
  880. }
  881. }
  882. }
  883. for (int sb = 0; sb < 8; sb++) {
  884. uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16;
  885. for(int m = 0; m < 4; m++) {
  886. const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
  887. for(int j = 0; j < ncols_interleaved; j++) {
  888. sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
  889. }
  890. }
  891. }
  892. }
  893. for (int m = 0; m < 4; m++) {
  894. for (int j = 0; j < ncols_interleaved; j++) {
  895. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
  896. }
  897. }
  898. }
  899. }
  900. }
  901. void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  902. const int qk = QK_K;
  903. const int nb = n / qk;
  904. const int ncols_interleaved = 8;
  905. const int blocklen = 8;
  906. assert (n % qk == 0);
  907. assert (nr % 4 == 0);
  908. assert (nc % ncols_interleaved == 0);
  909. UNUSED(s);
  910. UNUSED(bs);
  911. UNUSED(vx);
  912. UNUSED(vy);
  913. UNUSED(nr);
  914. UNUSED(nc);
  915. UNUSED(nb);
  916. UNUSED(ncols_interleaved);
  917. UNUSED(blocklen);
  918. float sumf[4][8];
  919. float sum_minf[4][8];
  920. int sumi1, sumi2, sumi3, sumi4;
  921. int sumi;
  922. for (int y = 0; y < nr / 4; y++) {
  923. const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
  924. for (int x = 0; x < nc / ncols_interleaved; x++) {
  925. const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb);
  926. for (int m = 0; m < 4; m++) {
  927. for (int j = 0; j < ncols_interleaved; j++) {
  928. sumf[m][j] = 0.0;
  929. sum_minf[m][j] = 0.0;
  930. }
  931. }
  932. for (int l = 0; l < nb; l++) {
  933. for (int k = 0; k < (qk / (4 * blocklen)); k++) {
  934. const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ;
  935. const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16;
  936. const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32;
  937. const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48;
  938. for (int m = 0; m < 4; m++) {
  939. for (int j = 0; j < ncols_interleaved; j++) {
  940. sumi1 = 0;
  941. sumi2 = 0;
  942. sumi3 = 0;
  943. sumi4 = 0;
  944. sumi = 0;
  945. int offset = ((k / 2) % 2) + j * 2;
  946. for (int i = 0; i < blocklen; ++i){
  947. const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3);
  948. const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3);
  949. const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3);
  950. const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3);
  951. sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]);
  952. sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]);
  953. sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]);
  954. sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]);
  955. sumi1 = sumi1 * (scales_0[offset] & 0xF);
  956. sumi2 = sumi2 * (scales_1[offset] & 0xF);
  957. sumi3 = sumi3 * (scales_2[offset] & 0xF);
  958. sumi4 = sumi4 * (scales_3[offset] & 0xF);
  959. sumi += sumi1 + sumi2 + sumi3 + sumi4;
  960. }
  961. sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
  962. }
  963. }
  964. }
  965. for(int sb = 0; sb < 8; sb++) {
  966. const uint8_t *mins = b_ptr[l].scales + sb * 16;
  967. for(int m = 0; m < 4; m++) {
  968. const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
  969. for(int j = 0; j < ncols_interleaved; j++) {
  970. int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]);
  971. sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
  972. }
  973. }
  974. }
  975. }
  976. for (int m = 0; m < 4; m++) {
  977. for (int j = 0; j < ncols_interleaved; j++) {
  978. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
  979. }
  980. }
  981. }
  982. }
  983. }
  984. void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  985. const int qk = QK8_0;
  986. const int nb = n / qk;
  987. const int ncols_interleaved = 4;
  988. const int blocklen = 4;
  989. assert (n % qk == 0);
  990. assert (nr % 4 == 0);
  991. assert (nc % ncols_interleaved == 0);
  992. UNUSED(s);
  993. UNUSED(bs);
  994. UNUSED(vx);
  995. UNUSED(vy);
  996. UNUSED(nr);
  997. UNUSED(nc);
  998. UNUSED(nb);
  999. UNUSED(ncols_interleaved);
  1000. UNUSED(blocklen);
  1001. {
  1002. float sumf[4][4];
  1003. int sumi;
  1004. for (int y = 0; y < nr / 4; y++) {
  1005. const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  1006. for (int x = 0; x < nc / ncols_interleaved; x++) {
  1007. const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb);
  1008. for (int m = 0; m < 4; m++) {
  1009. for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
  1010. }
  1011. for (int l = 0; l < nb; l++) {
  1012. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  1013. for (int m = 0; m < 4; m++) {
  1014. for (int j = 0; j < ncols_interleaved; j++) {
  1015. sumi = 0;
  1016. for (int i = 0; i < blocklen; ++i) {
  1017. const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
  1018. const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
  1019. sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
  1020. (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
  1021. }
  1022. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
  1023. }
  1024. }
  1025. }
  1026. }
  1027. for (int m = 0; m < 4; m++) {
  1028. for (int j = 0; j < ncols_interleaved; j++)
  1029. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
  1030. }
  1031. }
  1032. }
  1033. }
  1034. }
  1035. void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
  1036. const int qk = QK8_0;
  1037. const int nb = n / qk;
  1038. const int ncols_interleaved = 8;
  1039. const int blocklen = 8;
  1040. assert(n % qk == 0);
  1041. assert(nr % 4 == 0);
  1042. assert(nc % ncols_interleaved == 0);
  1043. float sumf[4][8];
  1044. int sumi;
  1045. for (int y = 0; y < nr / 4; y++) {
  1046. const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
  1047. for (int x = 0; x < nc / ncols_interleaved; x++) {
  1048. const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb);
  1049. for (int m = 0; m < 4; m++) {
  1050. for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
  1051. }
  1052. for (int l = 0; l < nb; l++) {
  1053. for (int k = 0; k < (qk / (2 * blocklen)); k++) {
  1054. for (int m = 0; m < 4; m++) {
  1055. for (int j = 0; j < ncols_interleaved; j++) {
  1056. sumi = 0;
  1057. for (int i = 0; i < blocklen; ++i) {
  1058. const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
  1059. const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
  1060. sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
  1061. (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
  1062. }
  1063. sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
  1064. }
  1065. }
  1066. }
  1067. }
  1068. for (int m = 0; m < 4; m++) {
  1069. for (int j = 0; j < ncols_interleaved; j++)
  1070. s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
  1071. }
  1072. }
  1073. }
  1074. }
  1075. } // extern "C"
  1076. static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave) {
  1077. block_q4_0x4 out;
  1078. for (int i = 0; i < 4; i++) {
  1079. out.d[i] = in[i].d;
  1080. }
  1081. const int end = QK4_0 * 2 / blck_size_interleave;
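// end = total quant bytes across the 4 inputs (4 * QK4_0 / 2 = 64) divided by the size of
// one interleaved chunk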
  1082. if (blck_size_interleave == 8) {
  1083. const uint64_t xor_mask = 0x8888888888888888ULL;
  1084. for (int i = 0; i < end; ++i) {
  1085. int src_id = i % 4;
  1086. int src_offset = (i / 4) * blck_size_interleave;
  1087. int dst_offset = i * blck_size_interleave;
  1088. uint64_t elems;
  1089. // Using memcpy to avoid unaligned memory accesses
  1090. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  1091. elems ^= xor_mask;
  1092. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  1093. }
  1094. } else if (blck_size_interleave == 4) {
  1095. const uint32_t xor_mask = 0x88888888;
  1096. for (int i = 0; i < end; ++i) {
  1097. int src_id = i % 4;
  1098. int src_offset = (i / 4) * blck_size_interleave;
  1099. int dst_offset = i * blck_size_interleave;
  1100. uint32_t elems;
  1101. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint32_t));
  1102. elems ^= xor_mask;
  1103. memcpy(&out.qs[dst_offset], &elems, sizeof(uint32_t));
  1104. }
  1105. } else {
  1106. GGML_ASSERT(false);
  1107. }
  1108. return out;
  1109. }
  1110. // interleave 8 block_q4_0s in blocks of blck_size_interleave
  1111. // returns an interleaved block_q4_0x8
  1112. // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
  1113. // first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
  1114. static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave) {
  1115. block_q4_0x8 out;
  1116. for (int i = 0; i < 8; i++) {
  1117. out.d[i] = in[i].d;
  1118. }
  1119. const int end = QK4_0 * 4 / blck_size_interleave;
  1120. const uint64_t xor_mask = 0x8888888888888888ULL;
  1121. for (int i = 0; i < end; ++i) {
  1122. int src_id = i % 8;
  1123. int src_offset = (i / 8) * blck_size_interleave;
  1124. int dst_offset = i * blck_size_interleave;
  1125. uint64_t elems;
  1126. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  1127. elems ^= xor_mask;
  1128. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  1129. }
  1130. return out;
  1131. }
  1132. static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) {
  1133. block_q4_Kx8 out;
  1134. // Delta (scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure
  1135. for (int i = 0; i < 8; i++) {
  1136. out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
  1137. }
  1138. for (int i = 0; i < 8; i++) {
  1139. out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
  1140. }
  1141. const int end = QK_K * 4 / blck_size_interleave;
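// end = total quant bytes across the 8 inputs (8 * QK_K / 2 = QK_K * 4) divided by the size
// of one interleaved chunk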
  1142. // Interleave Q4_K quants by taking 8 bytes at a time
  1143. for (int i = 0; i < end; ++i) {
  1144. int src_id = i % 8;
  1145. int src_offset = (i / 8) * blck_size_interleave;
  1146. int dst_offset = i * blck_size_interleave;
  1147. uint64_t elems;
  1148. memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
  1149. memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
  1150. }
  1151. // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
  1152. // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value)
  1153. // The output Q4_Kx8 structure has 96 bytes
  1154. // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure
  1155. // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures
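    // Layout of one 12-byte group written below, for sub-block i across the eight structures:
    //   bytes 0..3  : bits 0-5 = scale[i] of struct 0..3, bits 6-7 = bits 4-5 of scale[i] of struct 4..7
    //   bytes 4..7  : bits 0-5 = min[i]   of struct 0..3, bits 6-7 = bits 4-5 of min[i]   of struct 4..7
    //   bytes 8..11 : low nibble = bits 0-3 of scale[i] of struct 4..7, high nibble = bits 0-3 of min[i] of struct 4..7
    // i.e. the same 6-bit packing as the native Q4_K scales, but grouping the same sub-block across 8 structures.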
    uint8_t s[8], m[8];
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = in[j].scales[i] & 63;
            m[j] = in[j].scales[i + 4] & 63;
        }
        out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4);
    }
    for (int i = 0; i < 4; i++) {
        for (int j = 0; j < 8; j++) {
            s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15);
            m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4);
        }
        out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2);
        out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2);
        out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2);
        out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2);
        out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2);
        out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2);
        out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2);
        out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2);
        out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4);
        out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4);
        out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4);
        out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4);
    }
    return out;
}
static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) {
    block_q2_Kx8 out;
    // Delta (scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure
    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d;
    }
    for (int i = 0; i < 8; i++) {
        out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin;
    }
    const int end = QK_K * 2 / blck_size_interleave;
    // Interleave Q2_K quants by taking 8 bytes at a time
    for (int i = 0; i < end; ++i) {
        int src_id = i % 8;
        int src_offset = (i / 8) * blck_size_interleave;
        int dst_offset = i * blck_size_interleave;
        uint64_t elems;
        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
    }
    // The logic below unpacks and rearranges the scales and mins of Q2_K.
    // The Q2_K structure packs 16 scales and 16 mins into 16 bytes (4 bits per value).
    // The output Q2_Kx8 structure has 128 bytes for storing scales and mins.
    // Each 16-byte group holds the scales and mins of the corresponding sub-blocks from all eight Q2_K structures.
    // E.g. the first 16 bytes hold 16 scales and 16 mins - those of the first and second sub-blocks of each Q2_K structure.
    for (int i = 0; i < 128; i++) {
        // Index for selecting which q2k super block
        int src1 = (i % 16) / 2;
        // Index for selecting scale
        int src2 = ((i / 16) * 2) + (i % 2);
        out.scales[i] = in[src1].scales[src2];
    }
    return out;
}
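// Repack drivers.
//
// The source data is laid out row-major: nrow rows of nblocks quantized blocks each. Every driver
// walks the rows in groups of nrows_interleaved, gathers the x-th block of each row in the group
// into dst_tmp, and emits one interleaved block, so for nrows_interleaved = 4 the first output
// block combines src[0], src[nblocks], src[2 * nblocks] and src[3 * nblocks].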
static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
    constexpr int nrows_interleaved = 4;
    block_q4_0x4 * dst = (block_q4_0x4 *)t->data;
    const block_q4_0 * src = (const block_q4_0 *)data;
    block_q4_0 dst_tmp[4];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_0;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
    GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
    constexpr int nrows_interleaved = 8;
    block_q4_Kx8 * dst = (block_q4_Kx8 *)t->data;
    const block_q4_K * src = (const block_q4_K *) data;
    block_q4_K dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K));
    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q2_K);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 8;
    block_q2_Kx8 * dst = (block_q2_Kx8 *)t->data;
    const block_q2_K * src = (const block_q2_K *) data;
    block_q2_K dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK_K;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K));
    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
    GGML_ASSERT(interleave_block == 8);
    constexpr int nrows_interleaved = 8;
    block_q4_0x8 * dst = (block_q4_0x8 *)t->data;
    const block_q4_0 * src = (const block_q4_0 *) data;
    block_q4_0 dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nblocks = t->ne[0] / QK4_0;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0));
    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_q4_0x8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
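// IQ4_NL interleaving.
//
// Unlike Q4_0, the IQ4_NL nibbles are indices into the non-linear codebook kvalues_iq4nl used by
// the GEMM/GEMV kernels above, so no sign-bit remapping is applied here - the quant bytes are
// only gathered into the interleaved layout.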
static block_iq4_nlx4 make_block_iq4_nlx4(block_iq4_nl * in, unsigned int blck_size_interleave) {
    block_iq4_nlx4 out;
    for (int i = 0; i < 4; i++) {
        out.d[i] = in[i].d;
    }
    const int end = QK4_NL * 2 / blck_size_interleave;
    // TODO: this branch seems wrong
    //if (blck_size_interleave == 8) {
    //    for (int i = 0; i < end; ++i) {
    //        int src_id = i % 4;
    //        int src_offset = (i / 4) * blck_size_interleave;
    //        int dst_offset = i * blck_size_interleave;
    //        // Using memcpy to avoid unaligned memory accesses
    //        memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
    //    }
    //} else
    if (blck_size_interleave == 4) {
        for (int i = 0; i < end; ++i) {
            int src_id = i % 4;
            int src_offset = (i / 4) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;
            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
        }
    } else {
        GGML_ASSERT(false);
    }
    return out;
}
static int repack_iq4_nl_to_iq4_nl_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
    GGML_ASSERT(interleave_block == 4);
    const block_iq4_nl * src = (const block_iq4_nl *)data;
    block_iq4_nlx4 * dst = (block_iq4_nlx4 *)t->data;
    block_iq4_nl dst_tmp[4];
    int nrow = ggml_nrows(t);
    int nrows_interleaved = 4;
    int nblocks = t->ne[0] / QK4_NL;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_iq4_nlx4(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
static block_iq4_nlx8 make_block_iq4_nlx8(block_iq4_nl * in, unsigned int blck_size_interleave) {
    block_iq4_nlx8 out;
    for (int i = 0; i < 8; i++) {
        out.d[i] = in[i].d;
    }
    const int end = QK4_NL * 4 / blck_size_interleave;
    if (blck_size_interleave == 8) {
        for (int i = 0; i < end; ++i) {
            int src_id = i % 8;
            int src_offset = (i / 8) * blck_size_interleave;
            int dst_offset = i * blck_size_interleave;
            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
        }
    } else {
        GGML_ASSERT(false);
    }
    return out;
}
static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
    GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL);
    GGML_ASSERT(interleave_block == 8);
    const block_iq4_nl * src = (const block_iq4_nl *)data;
    block_iq4_nlx8 * dst = (block_iq4_nlx8 *)t->data;
    block_iq4_nl dst_tmp[8];
    int nrow = ggml_nrows(t);
    int nrows_interleaved = 8;
    int nblocks = t->ne[0] / QK4_NL;
    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl));
    if (t->ne[1] % nrows_interleaved != 0) {
        return -1;
    }
    for (int b = 0; b < nrow; b += nrows_interleaved) {
        for (int64_t x = 0; x < nblocks; x++) {
            for (int i = 0; i < nrows_interleaved; i++) {
                dst_tmp[i] = src[x + i * nblocks];
            }
            *dst++ = make_block_iq4_nlx8(dst_tmp, interleave_block);
        }
        src += nrows_interleaved * nblocks;
    }
    return 0;
    GGML_UNUSED(data_size);
}
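// C++ dispatch layer.
//
// The templates below map a (block type, interleave size, column count[, activation type]) tuple
// onto the matching C kernels, e.g. repack<block_q4_0, 8, 8> calls repack_q4_0_to_q4_0_8_bl and
// gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0> calls ggml_gemm_q4_0_8x8_q8_0. tensor_traits<...> further
// down ties repack, gemv and gemm together for one packed layout.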
namespace ggml::cpu::repack {
// repack
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
int repack(struct ggml_tensor *, const void *, size_t);

// TODO: generalise.
template <> int repack<block_q4_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 4, data, data_size);
}
template <> int repack<block_q4_0, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_4_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_0, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_0_to_q4_0_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size);
}
// TODO: needs to be revisited
//template <> int repack<block_iq4_nl, 8, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
//    return repack_iq4_nl_to_iq4_nl_4_bl(t, 8, data, data_size);
//}
template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
    return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
}
// gemv
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemv(int, float *, size_t, const void *, const void *, int, int);

template <> void gemv<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
// gemm
template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE>
void gemm(int, float *, size_t, const void *, const void *, int, int);

template <> void gemm<block_q4_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
    ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
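// tensor_traits_base is the type-erased interface stored in tensor->extra by the repack buffer;
// tensor_traits<...> implements it for one packed layout: work_size() reserves scratch space for
// src1 converted to PARAM_TYPE (plus, for MUL_MAT_ID, the expert row-mapping table), and
// compute_forward() routes GGML_OP_MUL_MAT / GGML_OP_MUL_MAT_ID to the kernels selected above.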
class tensor_traits_base : public ggml::cpu::tensor_traits {
  public:
    virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0;
};

template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> class tensor_traits : public tensor_traits_base {
    bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
        // not really a GGML_TYPE_Q8_0 but same size.
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
                    return true;
                }
            case GGML_OP_MUL_MAT_ID:
                {
                    size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1]));
                    size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block.
                    const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert
                    const int64_t ne12 = op->src[1]->ne[2]; // n_tokens
                    const size_t sizeof_mmid_row_mapping = sizeof(int64_t);
                    size += sizeof_mmid_row_mapping*ne02*(ne12 + 1);
                    return true;
                }
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }

    bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override {
        switch (op->op) {
            case GGML_OP_MUL_MAT:
                forward_mul_mat(params, op);
                return true;
            case GGML_OP_MUL_MAT_ID:
                forward_mul_mat_id(params, op);
                return true;
            default:
                // GGML_ABORT("fatal error");
                break;
        }
        return false;
    }
    void forward_mul_mat_one_chunk(ggml_compute_params * params,
                                   ggml_tensor * op,
                                   int64_t src0_start,
                                   int64_t src0_end,
                                   int64_t src1_start,
                                   int64_t src1_end) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor * dst = op;
        GGML_TENSOR_BINARY_OP_LOCALS
        const size_t src1_col_stride = ggml_row_size(PARAM_TYPE, ne10);
        GGML_ASSERT(ne03 == 1 && ne13 == 1);
        GGML_ASSERT(ne12 % ne02 == 0);
        const int64_t r2 = ne12 / ne02;
        const int64_t i12 = src1_start / ne1;
        const int64_t i11 = src1_start - i12 * ne1;
        // Determine batch index
        const int64_t i02 = i12 / r2;
        const int64_t i1 = i11;
        const int64_t i2 = i12;
        const char * src0_ptr = (const char *) src0->data + i02 * nb02;
        const char * src1_ptr = (const char *) params->wdata + (i11 + i12 * ne11) * src1_col_stride;
        char * dst_ptr = ((char *) dst->data + (i1 * nb1 + i2 * nb2));
        const int64_t nrows = src1_end - src1_start;
        const int64_t ncols = src0_end - src0_start;
        GGML_ASSERT(src1_ptr + src1_col_stride * nrows <= (const char *) params->wdata + params->wsize);
        // If there are more than three rows in src1, use gemm; otherwise, use gemv.
        if (nrows > 3) {
            gemm<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr) + src0_start, nb1 / nb0,
                                                             src0_ptr + src0_start * nb01, src1_ptr,
                                                             nrows - (nrows % 4), ncols);
        }
        for (int iter = nrows - (nrows % 4); iter < nrows; iter++) {
            gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, (float *) (dst_ptr + (iter * nb1)) + src0_start,
                                                             ne01, src0_ptr + src0_start * nb01,
                                                             src1_ptr + (src1_col_stride * iter), 1 /* nrows */, ncols);
        }
    }
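    // forward_mul_mat() proceeds in three steps:
    //   1. quantize src1 (F32) into the work buffer as PARAM_TYPE - 4 rows at a time via
    //      ggml_quantize_mat_t, the remainder with from_float,
    //   2. split the rows of src0 into NB_COLS-aligned chunks (one chunk list per src1 plane),
    //   3. have the threads pull chunk indices from the threadpool counter and run
    //      forward_mul_mat_one_chunk() on each.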
    void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        ggml_tensor * dst = op;
        GGML_TENSOR_BINARY_OP_LOCALS
        const int ith = params->ith;
        const int nth = params->nth;
        GGML_ASSERT(ne0 == ne01);
        GGML_ASSERT(ne1 == ne11);
        GGML_ASSERT(ne2 == ne12);
        GGML_ASSERT(ne3 == ne13);
        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);
        // TODO: general batched mul mat for 4D tensors
        // Currently only 3D tensors are supported
        GGML_ASSERT(ne03 == 1);
        GGML_ASSERT(ne13 == 1);
        GGML_ASSERT(ne3 == 1);
        GGML_ASSERT(src1->type == GGML_TYPE_F32);
        GGML_ASSERT(ggml_n_dims(op->src[0]) == 2);
        // GGML_ASSERT(ggml_n_dims(op->src[1]) == 2);
        char * wdata = static_cast<char *>(params->wdata);
        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
        const size_t nbw2 = nbw1 * ne11;
        assert(params->wsize >= nbw2 * ne12);
        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
        // INFO: quantization is done per plane to avoid extra complexity in the chunking.
        // Flattening dimensions that are not a multiple of INTER_SIZE would require extra handling
        // depending on how the planes are broadcast.
        for (int64_t i12 = 0; i12 < ne12; i12++) {
            char * data_ptr = (char *) src1->data + i12 * nb12;
            char * wdata_ptr = wdata + i12 * nbw2;
            for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                ggml_quantize_mat_t<INTER_SIZE, PARAM_TYPE>((float *) (data_ptr + i11 * nb11),
                                                            (void *) (wdata_ptr + i11 * nbw1), 4, ne10);
            }
            const int64_t i11_processed = ne11 - ne11 % 4;
            for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
                from_float((float *) (data_ptr + i11 * nb11), (void *) (wdata_ptr + i11 * nbw1), ne10);
            }
        }
        // disable for NUMA
        const bool disable_chunking = ggml_is_numa();
        // 4x chunks per thread
        const int64_t nr0 = ggml_nrows(op->src[0]);
        int nth_scaled = nth * 4;
        int64_t chunk_size0 = (nr0 + nth_scaled - 1) / nth_scaled;
        int64_t nchunk0 = (nr0 + chunk_size0 - 1) / chunk_size0;
        // src1 is chunked only by full planes.
        // When flattening, dimensions that are not a multiple of the q8 INTER_SIZE would have to be
        // routed through GEMV.
        // nchunk1 = ne12 also keeps the chunking unchanged for models with no 3D tensors,
        // so their performance is not affected.
        int64_t nchunk1 = ne12;
        // Ensure minimum chunk size to avoid alignment issues with high thread counts
        // Minimum chunk size should be at least NB_COLS to prevent overlapping chunks after alignment
        const int64_t min_chunk_size = NB_COLS;
        if (nchunk0 > 0 && (nr0 / nchunk0) < min_chunk_size && nr0 >= min_chunk_size) {
            nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
        }
        int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        // Only increase nchunk0 to nth if it won't make chunks too small
        if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
            nchunk0 = nth;
            dr0 = (nr0 + nchunk0 - 1) / nchunk0;
        }
        // Ensure nchunk doesn't exceed the number of rows divided by the minimum chunk size
        // This prevents creating too many tiny chunks that could overlap after alignment
        const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
        nchunk0 = MIN(nchunk0, max_nchunk);
        if (ith == 0) {
            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
            ggml_threadpool_chunk_set(params->threadpool, nth);
        }
        ggml_barrier(params->threadpool);
        // The first chunk comes from our thread_id, the rest will get auto-assigned.
        int current_chunk = ith;
        while (current_chunk < nchunk0 * nchunk1) {
            const int64_t ith0 = current_chunk % nchunk0;
            const int64_t ith1 = current_chunk / nchunk0;
            int64_t src0_start = dr0 * ith0;
            int64_t src0_end = MIN(src0_start + dr0, nr0);
            // full-plane range for src1
            int64_t src1_start = ith1 * ne11;
            int64_t src1_end = (ith1 + 1) * ne11;
            // Align boundaries to NB_COLS - round up to ensure all data is included
            // The chunk size limiting above ensures chunks are large enough to prevent overlaps
            src0_start = (src0_start % NB_COLS) ? src0_start + NB_COLS - (src0_start % NB_COLS) : src0_start;
            src0_end = (src0_end % NB_COLS) ? src0_end + NB_COLS - (src0_end % NB_COLS) : src0_end;
            src0_end = MIN(src0_end, ne01);
            // Skip empty chunks (the alignment above can push src0_start past src0_end)
            if (src0_start >= src0_end) {
                current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
                continue;
            }
            forward_mul_mat_one_chunk(params, dst, src0_start, src0_end, src1_start, src1_end);
            current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
        }
    }
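    // forward_mul_mat_id() work buffer layout (sized in work_size() above):
    //   [ src1 quantized to PARAM_TYPE : nbw3 bytes, padded to sizeof(int64_t) ]
    //   [ matrix_row_counts            : n_as * sizeof(int64_t)                ]
    //   [ matrix_rows                  : n_as * ne12 mmid_row_mapping entries  ]
    // ids (src[2]) selects which expert (src0 slice) each src1 row is routed to; the rows are
    // grouped per expert and then processed one row at a time with gemv.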
    void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) {
        const ggml_tensor * src0 = op->src[0];
        const ggml_tensor * src1 = op->src[1];
        const ggml_tensor * ids = op->src[2];
        ggml_tensor * dst = op;
        GGML_TENSOR_BINARY_OP_LOCALS
        const int ith = params->ith;
        const int nth = params->nth;
        const ggml_from_float_t from_float = ggml_get_type_traits_cpu(PARAM_TYPE)->from_float;
        // we don't support permuted src0 or src1
        GGML_ASSERT(nb00 == ggml_type_size(src0->type));
        GGML_ASSERT(nb10 == ggml_type_size(src1->type));
        // dst cannot be transposed or permuted
        GGML_ASSERT(nb0 == sizeof(float));
        GGML_ASSERT(nb0 <= nb1);
        GGML_ASSERT(nb1 <= nb2);
        GGML_ASSERT(nb2 <= nb3);
        GGML_ASSERT(ne03 == 1);
        GGML_ASSERT(ne13 == 1);
        GGML_ASSERT(ne3 == 1);
        GGML_ASSERT(src1->type == GGML_TYPE_F32);
        // row groups
        const int n_ids = ids->ne[0]; // n_expert_used
        const int n_as = ne02;        // n_expert
        const size_t nbw1 = ggml_row_size(PARAM_TYPE, ne10);
        const size_t nbw2 = nbw1*ne11;
        const size_t nbw3 = nbw2*ne12;
        struct mmid_row_mapping {
            int32_t i1;
            int32_t i2;
        };
        GGML_ASSERT(params->wsize >=
                    (GGML_PAD(nbw3, sizeof(int64_t)) +
                     n_as*(ne12 + 1)*sizeof(mmid_row_mapping))
                   );
        auto * wdata = (char *)params->wdata;
        auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t));
        // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t)
        auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
        struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12]
        // src1: float32 => param type
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
                from_float((float *)((char *) src1->data + i12 * nb12 + i11 * nb11),
                           (void *) (wdata + i12 * nbw2 + i11 * nbw1),
                           ne10);
            }
        }
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ne12 + (i1)]
        if (ith == 0) {
            // initialize matrix_row_counts
            memset(matrix_row_counts, 0, n_as * sizeof(int64_t));
            // group rows by src0 matrix
            for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
                for (int32_t id = 0; id < n_ids; ++id) {
                    const int32_t i02 =
                        *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]);
                    GGML_ASSERT(i02 >= 0 && i02 < n_as);
                    MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 };
                    matrix_row_counts[i02] += 1;
                }
            }
        }
        ggml_barrier(params->threadpool);
        // compute each matrix multiplication in sequence
        for (int cur_a = 0; cur_a < n_as; ++cur_a) {
            const int64_t cne1 = matrix_row_counts[cur_a];
            if (cne1 == 0) {
                continue;
            }
            const auto * src0_cur = (const char *) src0->data + cur_a*nb02;
            //const int64_t nr0 = ne01; // src0 rows
            const int64_t nr1 = cne1; // src1 rows
            int64_t src0_cur_start = (ith * ne01) / nth;
            int64_t src0_cur_end = ((ith + 1) * ne01) / nth;
            // Align boundaries to NB_COLS - round up to ensure all data is included
            src0_cur_start = (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start;
            src0_cur_end = (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end;
            if (src0_cur_end > ne01) {
                src0_cur_end = ne01;
            }
            if (src0_cur_start >= src0_cur_end) {
                return;
            }
            for (int ir1 = 0; ir1 < nr1; ir1++) {
                struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1);
                const int id = row_mapping.i1; // selected expert index
                const int64_t i11 = id % ne11;
                const int64_t i12 = row_mapping.i2; // row index in src1
                const int64_t i1 = id;  // selected expert index
                const int64_t i2 = i12; // row
                const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2);
                gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00,
                    (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01,
                    src0_cur + src0_cur_start * nb01,
                    src1_col, 1, src0_cur_end - src0_cur_start);
            }
        }
#undef MMID_MATRIX_ROW
    }
    int repack(struct ggml_tensor * t, const void * data, size_t data_size) override {
        GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type),
                       (int) NB_COLS, (int) INTER_SIZE);
        return ggml::cpu::repack::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size);
    }
};
} // namespace ggml::cpu::repack
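// Packing selection: pick the widest interleaved layout the current CPU can actually use for the
// tensor type, subject to the row count (ne[1]) being a multiple of the packing width. Returns
// nullptr when no repacking applies, in which case the tensor keeps its original layout.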
static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) {
    // instance for Q4
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
    // instance for Q4_K
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
    static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
    // instance for Q2
    static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K;
    // instance for IQ4
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
    static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;

    if (cur->type == GGML_TYPE_Q4_0) {
        if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_0_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &q4_0_4x4_q8_0;
            }
        }
    } else if (cur->type == GGML_TYPE_Q4_K) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x8_q8_K;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 8 == 0) {
                return &q4_K_8x4_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_Q2_K) {
        if (ggml_cpu_has_avx512()) {
            if (cur->ne[1] % 8 == 0) {
                return &q2_K_8x8_q8_K;
            }
        }
    } else if (cur->type == GGML_TYPE_IQ4_NL) {
        if (ggml_cpu_has_avx2()) {
            if (cur->ne[1] % 8 == 0) {
                return &iq4_nl_8x8_q8_0;
            }
        }
        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
            if (cur->ne[1] % 4 == 0) {
                return &iq4_nl_4x4_q8_0;
            }
        }
    }
    return nullptr;
}
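// Buffer glue: init_tensor stores the selected tensor_traits in tensor->extra, and set_tensor
// repacks the incoming plain-layout data instead of doing the usual memcpy, so weights uploaded
// into a CPU_REPACK buffer are converted transparently at load time.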
static enum ggml_status ggml_backend_cpu_repack_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
    tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_repack_get_optimal_repack_type(tensor));
    GGML_UNUSED(buffer);
    return GGML_STATUS_SUCCESS;
}

static void ggml_backend_cpu_repack_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
                                                      const void * data, size_t offset, size_t size) {
    GGML_ASSERT(offset == 0);
    GGML_ASSERT(size == ggml_nbytes(tensor));
    auto tensor_traits = (ggml::cpu::repack::tensor_traits_base *) tensor->extra;
    auto OK = tensor_traits->repack(tensor, data, size);
    GGML_ASSERT(OK == 0);
    GGML_UNUSED(buffer);
}

static const char * ggml_backend_cpu_repack_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
    return "CPU_REPACK";
    GGML_UNUSED(buft);
}

static ggml_backend_buffer_t ggml_backend_cpu_repack_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    if (buffer == nullptr) {
        return nullptr;
    }
    buffer->buft = buft;
    buffer->iface.init_tensor = ggml_backend_cpu_repack_buffer_init_tensor;
    buffer->iface.set_tensor = ggml_backend_cpu_repack_buffer_set_tensor;
    buffer->iface.get_tensor = nullptr;
    buffer->iface.cpy_tensor = nullptr;
    return buffer;
}

static size_t ggml_backend_cpu_repack_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    return TENSOR_ALIGNMENT;
    GGML_UNUSED(buft);
}
namespace ggml::cpu::repack {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
    bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT &&
            op->src[0]->buffer &&
            (ggml_n_dims(op->src[0]) == 2) &&
            op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type() &&
            ggml_repack_get_optimal_repack_type(op->src[0])
            ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
            // may be possible if Q8_0 packed...
        } else if (op->op == GGML_OP_MUL_MAT_ID
                   && op->src[0]->buffer
                   && (ggml_n_dims(op->src[0]) == 3)
                   && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()
                   && ggml_repack_get_optimal_repack_type(op->src[0])
                   ) {
            if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                return false;
            }
            if (op->src[1]->type == GGML_TYPE_F32) {
                return true;
            }
            //if (op->src[1]->type == GGML_TYPE_Q8_0) {
            //    return true;
            //}
        }
        return false;
    }

    ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
        if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) {
            if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_repack_buffer_type()) {
                return (ggml::cpu::tensor_traits *) op->src[0]->extra;
            }
        }
        return nullptr;
    }
};
} // namespace ggml::cpu::repack
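// Public entry point. A minimal usage sketch (illustrative only; assumes the usual ggml backend
// helpers such as ggml_backend_alloc_ctx_tensors_from_buft from ggml-alloc.h):
//
//   ggml_backend_buffer_type_t buft = ggml_backend_cpu_repack_buffer_type();
//   // allocating the weights in this buffer type makes set_tensor() repack them on upload
//   ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
//
// The buffer delegates allocation to the regular CPU buffer type and only overrides
// init_tensor/set_tensor; get_tensor/cpy_tensor are disabled since the repacked bytes are not
// compatible with the original layout.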
ggml_backend_buffer_type_t ggml_backend_cpu_repack_buffer_type(void) {
    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_repack = {
        /* .iface    = */ {
            /* .get_name         = */ ggml_backend_cpu_repack_buffer_type_get_name,
            /* .alloc_buffer     = */ ggml_backend_cpu_repack_buffer_type_alloc_buffer,
            /* .get_alignment    = */ ggml_backend_cpu_repack_buffer_type_get_alignment,
            /* .get_max_size     = */ nullptr, // defaults to SIZE_MAX
            /* .get_alloc_size   = */ nullptr, // defaults to ggml_nbytes
            /* .is_host          = */ nullptr,
        },
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
        /* .context = */ new ggml::cpu::repack::extra_buffer_type(),
    };

    return &ggml_backend_cpu_buffer_type_repack;
}