// ggml-metal.metal

#include <metal_stdlib>

using namespace metal;

#define MAX(x, y) ((x) > (y) ? (x) : (y))

#define QK4_0 32
#define QR4_0 2
typedef struct {
    half    d;             // delta
    uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;

#define QK4_1 32
typedef struct {
    half    d;             // delta
    half    m;             // min
    uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
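
// Editorial sketch (not part of ggml, not referenced by any kernel below):
// the reference decoding of a q4_0 block, for orientation. Each byte of qs
// packs two 4-bit quants: the low nibble holds element i, the high nibble
// holds element i + QK4_0/2, and both decode as d * (q - 8). For q4_1 the
// mapping is d * q + m instead. The helper name dequant_q4_0_ref is ours.
inline float dequant_q4_0_ref(device const block_q4_0 * b, int i) {
    const uint8_t byte = b->qs[i % (QK4_0/2)];
    const int     q    = (i < QK4_0/2) ? (byte & 0x0F) : (byte >> 4);
    return (float) b->d * (q - 8);
}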

kernel void kernel_add(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] + src1[tpig];
}

// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] + src1[tpig % ne00];
}

kernel void kernel_mul(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig];
}

// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_mul_row(
        device const float * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * src1[tpig % ne00];
}

kernel void kernel_scale(
        device const float * src0,
        device       float * dst,
        constant     float & scale,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = src0[tpig] * scale;
}

kernel void kernel_silu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    float x = src0[tpig];
    dst[tpig] = x / (1.0f + exp(-x));
}

kernel void kernel_relu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    dst[tpig] = max(0.0f, src0[tpig]);
}

constant float GELU_COEF_A    = 0.044715f;
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

kernel void kernel_gelu(
        device const float * src0,
        device       float * dst,
        uint tpig[[thread_position_in_grid]]) {
    float x = src0[tpig];
    dst[tpig] = 0.5f*x*(1.0f + tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}
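
// The kernel above uses the common tanh approximation of GELU:
//   GELU(x) ~= 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
// SQRT_2_OVER_PI and GELU_COEF_A are the two constants of that formula;
// the exact form, 0.5*x*(1 + erf(x/sqrt(2))), is not used here.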

kernel void kernel_soft_max(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        threadgroup float  * buf [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    // parallel max
    buf[tpitg[0]] = -INFINITY;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg[0]/2; i > 0; i /= 2) {
        if (tpitg[0] < i) {
            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg[0] == 0) {
        buf[0] = buf[0];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float max = buf[0];

    // parallel sum
    buf[tpitg[0]] = 0.0f;
    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        buf[tpitg[0]] += exp(psrc0[i00] - max);
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg[0]/2; i > 0; i /= 2) {
        if (tpitg[0] < i) {
            buf[tpitg[0]] += buf[tpitg[0] + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg[0] == 0) {
        buf[0] = buf[0];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float sum = buf[0];

    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
        pdst[i00] = exp(psrc0[i00] - max) / sum;
    }
}
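
// kernel_soft_max computes softmax(x) = exp(x - max(x)) / sum(exp(x - max(x)))
// over each row in three passes: a parallel max (for numerical stability),
// a parallel sum of the shifted exponentials, and a normalization pass.
// The tree reductions assume ntg[0] is a power of two; the "broadcast" steps
// are effectively no-ops kept for symmetry with the reduction pattern.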

kernel void kernel_diag_mask_inf(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant       int & n_past,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i02 = tpig[2];
    const int64_t i01 = tpig[1];
    const int64_t i00 = tpig[0];

    if (i00 > n_past + i01) {
        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
    } else {
        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
    }
}
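
// kernel_diag_mask_inf implements the causal attention mask: element
// (i00, i01) is replaced with -INFINITY whenever i00 > n_past + i01, i.e.
// whenever a query at row i01 would attend to a position in its future.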

kernel void kernel_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);

    // MEAN
    // parallel sum
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += x[i00];
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float mean = sum[0];

    // recenter
    device float * y = dst + tgpig*ne00;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = x[i00] - mean;
    }

    // VARIANCE
    // parallel sum
    sum[tpitg] = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        sum[tpitg] += y[i00] * y[i00];
    }

    // reduce
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = ntg/2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // broadcast
    if (tpitg == 0) {
        sum[0] /= ne00;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float variance = sum[0];

    const float scale = 1.0f/sqrt(variance + eps);
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        y[i00] = y[i00] * scale;
    }
}
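
// kernel_norm is LayerNorm without affine parameters:
//   y = (x - mean(x)) / sqrt(var(x) + eps)
// Mean and variance each use the same two-step pattern: a strided per-thread
// partial sum followed by a power-of-two tree reduction in threadgroup memory.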

kernel void kernel_rms_norm(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
        constant     float & eps,
        threadgroup float  * sum [[threadgroup(0)]],
        uint tgpig[[threadgroup_position_in_grid]],
        uint tpitg[[thread_position_in_threadgroup]],
        uint sgitg[[simdgroup_index_in_threadgroup]],
        uint tiisg[[thread_index_in_simdgroup]],
        uint   ntg[[threads_per_threadgroup]]) {
    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
    device const float * x_scalar = (device const float *) x;
    float4 sumf    = 0;
    float  all_sum = 0;

    // parallel sum
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        sumf += x[i00] * x[i00];
    }
    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
    all_sum = simd_sum(all_sum);
    if (tiisg == 0) {
        sum[sgitg] = all_sum;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // reduce the per-simdgroup partials; the simd group count is ntg / 32
    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
            sum[tpitg] += sum[tpitg + i];
        }
    }
    if (tpitg == 0) {
        // scalar tail: accumulate the squares of the remaining values
        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
            sum[0] += x_scalar[i] * x_scalar[i];
        }
        sum[0] /= ne00;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float mean  = sum[0];
    const float scale = 1.0f/sqrt(mean + eps);

    device float4 * y = (device float4 *) (dst + tgpig*ne00);
    device float * y_scalar = (device float *) y;
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
        y[i00] = x[i00] * scale;
    }
    if (tpitg == 0) {
        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
            y_scalar[i00] = x_scalar[i00] * scale;
        }
    }
}
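
// kernel_rms_norm computes y = x / sqrt(mean(x^2) + eps). Unlike kernel_norm
// it reduces with simd_sum inside each SIMD group first and only uses
// threadgroup memory to combine the ntg/32 per-simdgroup partials. Loads and
// stores go through float4 for bandwidth; the scalar tails handle an ne00
// that is not a multiple of 4.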

// function to calculate the inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float2 acc = 0.f;
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                + yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
                + yl[i + 9] * (qs[i / 2] & 0xF000);
    }
    return d * (sumy * -8.f + acc[0] + acc[1]);
}

// function to calculate the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_1/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
    float d = qb_curr->d;
    float m = qb_curr->m;
    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
    float2 acc = 0.f;
    for (int i = 0; i < 8; i += 2) {
        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                + yl[i + 1] * (qs[i / 2] & 0x0F00);
        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
                + yl[i + 9] * (qs[i / 2] & 0xF000);
    }
    return d * (acc[0] + acc[1]) + sumy * m;
}
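
// How the masking trick in the two block_q_n_dot_y overloads works: a 16-bit
// load of qs holds four 4-bit quants at bit offsets 0, 4, 8 and 12. Masking
// with 0x000F, 0x00F0, 0x0F00 and 0xF000 extracts each quant without shifting
// it down, so the products come out scaled by 1, 16, 256 and 4096. The caller
// compensates by pre-dividing the matching y values by those factors (see
// mul_vec_q_n_f32 below, which stores yb[i]/1, /256, /16 and /4096 into yl);
// e.g. yl[i+1] * (qs[i/2] & 0x0F00) == (y/256) * (q << 8) == y * q.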

// putting these values in the kernel causes a significant performance penalty
#define N_DST 4        // each SIMD group works on 4 rows
#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32

// Note: This is a template, but strictly speaking it only applies to
//       quantizations where the block size is 32. It also does not
//       guard against the number of rows not being divisible by
//       N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
    const int nb = ne00/QK4_0;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int first_row = (r0 * nsg + sgitg) * nr;
    const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;

    float yl[16]; // src1 vector cache
    float sumf[nr] = {0.f};

    const int ix = tiisg/2;
    const int il = 8*(tiisg%2);

    device const float * yb = y + ix * QK4_0 + il;

    // each thread in a SIMD group deals with half a block.
    for (int ib = ix; ib < nb; ib += nw/2) {
        float sumy = 0;
        for (int i = 0; i < 8; i += 2) {
            sumy += yb[i] + yb[i+1];
            yl[i+0] = yb[i+ 0];
            yl[i+1] = yb[i+ 1]/256.f;

            sumy += yb[i+16] + yb[i+17];
            yl[i+8] = yb[i+16]/16.f;
            yl[i+9] = yb[i+17]/4096.f;
        }

        for (int row = 0; row < nr; row++) {
            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
        }

        yb += QK4_0 * 16;
    }

    for (int row = 0; row < nr; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0 && first_row + row < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
        }
    }
}
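
// Work distribution in mul_vec_q_n_f32, with the defaults above: each SIMD
// group accumulates nr = N_DST = 4 rows and each threadgroup nsg*nr = 8 rows;
// within a SIMD group, each of the nw = 32 threads handles half a block
// (ix picks the block, il picks which 16 quants) before simd_sum combines
// the per-thread partials. The final store is guarded by
// first_row + row < ne01; the loads are not, per the note above.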

kernel void kernel_mul_mat_q4_0_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_q4_1_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
    mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}

kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        threadgroup float  * sum [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3  tpig[[thread_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3  tptg[[threads_per_threadgroup]]) {
    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t im = tgpig.z;

    device const half  * x = (device const half  *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    sum[tpitg.x] = 0.0f;

    for (int i = tpitg.x; i < ne00; i += tptg.x) {
        sum[tpitg.x] += (float) x[i] * (float) y[i];
    }

    // accumulate the sum from all threads in the threadgroup
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (uint i = tptg.x/2; i > 0; i /= 2) {
        if (tpitg.x < i) {
            sum[tpitg.x] += sum[tpitg.x + i];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tpitg.x == 0) {
        dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
    }
}
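
// kernel_mul_mat_f16_f32 is a plain matrix-vector style kernel: one
// threadgroup per output element, each thread accumulating a strided slice
// of the dot product, then the usual power-of-two tree reduction in
// threadgroup memory. The im/(ne12/ne02) term maps a src1 batch index onto
// the (possibly smaller) number of src0 batches.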

kernel void kernel_alibi_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant     float & m0,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
    float m_k = pow(m0, i2 + 1);
    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
    }
}
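
// kernel_alibi_f32 adds the ALiBi (Attention with Linear Biases) term to the
// attention scores: for head index i2 the bias of column i00 is
//   m_k * (i00 - ne00 + 1),  with m_k = m0^(i2 + 1)
// so each head gets its own geometric slope and more distant positions
// receive a larger penalty.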

kernel void kernel_rope(
        device const  void * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        constant       int & n_past,
        constant       int & n_dims,
        constant       int & mode,
        constant     float & freq_base,
        constant     float & freq_scale,
        uint3 tpig[[thread_position_in_grid]]) {
    const int64_t i3 = tpig[2];
    const int64_t i2 = tpig[1];
    const int64_t i1 = tpig[0];

    const bool is_neox = mode & 2;
    const float theta_scale = pow(freq_base, -2.0f/n_dims);

    const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2);

    float theta = freq_scale * (float)p;

    if (!is_neox) {
        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
            const float cos_theta = cos(theta);
            const float sin_theta = sin(theta);

            theta *= theta_scale;

            device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
            device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);

            const float x0 = src[0];
            const float x1 = src[1];

            dst_data[0] = x0*cos_theta - x1*sin_theta;
            dst_data[1] = x0*sin_theta + x1*cos_theta;
        }
    } else {
        // TODO: implement
    }
}
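
// kernel_rope applies rotary position embedding to consecutive pairs
// (x0, x1), rotating each pair by an angle theta that starts at
// freq_scale * p for position p and is multiplied by
// theta_scale = freq_base^(-2/n_dims) from one pair to the next:
//   x0' = x0*cos(theta) - x1*sin(theta)
//   x1' = x0*sin(theta) + x1*cos(theta)
// Only the default adjacent-pairs layout is handled; the is_neox variant
// is still a TODO.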

kernel void kernel_cpy_f16_f16(
        device const  half * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}

kernel void kernel_cpy_f32_f16(
        device const float * src0,
        device        half * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}

kernel void kernel_cpy_f32_f32(
        device const float * src0,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant   int64_t & ne03,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant  uint64_t & nb03,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        constant   int64_t & ne2,
        constant   int64_t & ne3,
        constant  uint64_t & nb0,
        constant  uint64_t & nb1,
        constant  uint64_t & nb2,
        constant  uint64_t & nb3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3   ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = tgpig[2];
    const int64_t i02 = tgpig[1];
    const int64_t i01 = tgpig[0];

    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

    const int64_t i3 = n / (ne2*ne1*ne0);
    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
        dst_data[i00] = src[0];
    }
}
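
// The three kernel_cpy_* variants share one addressing scheme: the
// threadgroup coordinates select a source row (i01, i02, i03), the flat index
// n of that row's first element is re-decomposed into destination coordinates
// (i0..i3) using the destination shape, and each thread then copies a strided
// slice of the row. Walking both tensors through their own nb* byte strides
// makes the copy layout-agnostic.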

//============================================ k-quants ======================================================

#ifndef QK_K
#define QK_K 256
#else
static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
#endif

#if QK_K == 256
#define K_SCALE_SIZE 12
#else
#define K_SCALE_SIZE 4
#endif

typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];      // quants
    half d;                  // super-block scale for quantized scales
    half dmin;               // super-block scale for quantized mins
} block_q2_K;
// 84 bytes / block

typedef struct {
    uint8_t hmask[QK_K/8]; // quants - high bit
    uint8_t qs[QK_K/4];    // quants - low 2 bits
#if QK_K == 64
    uint8_t scales[2];
#else
    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
#endif
    half d;                // super-block scale
} block_q3_K;

#if QK_K == 64
typedef struct {
    half    d[2];       // super-block scales/mins
    uint8_t scales[2];
    uint8_t qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#else
typedef struct {
    half d;                       // super-block scale for quantized scales
    half dmin;                    // super-block scale for quantized mins
    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];           // 4-bit quants
} block_q4_K;
#endif

#if QK_K == 64
typedef struct {
    half    d;               // super-block scale
    int8_t  scales[QK_K/16]; // 8-bit block scales
    uint8_t qh[QK_K/8];      // quants, high bit
    uint8_t qs[QK_K/2];      // quants, low 4 bits
} block_q5_K;
#else
typedef struct {
    half d;                    // super-block scale for quantized scales
    half dmin;                 // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
    uint8_t qh[QK_K/8];        // quants, high bit
    uint8_t qs[QK_K/2];        // quants, low 4 bits
} block_q5_K;
// 176 bytes / block
#endif

typedef struct {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
} block_q6_K;
// 210 bytes / block
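
// Common to all k-quants: QK_K values form one super-block, stored as
// sub-blocks that share the half-precision super-block scale(s) d (and dmin
// where mins are used), while the per-sub-block scales/mins are themselves
// quantized to the bit widths noted in the comments above.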

static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
    uchar4 r;
    if (j < 4) {
        r[0] = q[j+0] & 63;
        r[2] = q[j+1] & 63;
        r[1] = q[j+4] & 63;
        r[3] = q[j+5] & 63;
    } else {
        r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
        r[1] = (q[j+4] >>  4) | ((q[j-0] >> 6) << 4);
        r[3] = (q[j+5] >>  4) | ((q[j+1] >> 6) << 4);
    }
    return r;
}
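
// For QK_K == 256 the 12 scale bytes pack eight 6-bit scales and eight 6-bit
// mins. get_scale_min_k4(j, q) returns {scale j, min j, scale j+1, min j+1}:
// entries 0..3 live in the low 6 bits of q[0..7], while entries 4..7 combine
// the low/high nibbles of q[8..11] with the top two bits of q[0..7].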

//====================================== dot products =========================

kernel void kernel_mul_mat_q2_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[32];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q2_K) * nb;

#if QK_K == 256
    const int ix = tiisg/8;  // 0...3
    const int it = tiisg%8;  // 0...7
    const int im = it/4;     // 0 or 1
    const int ir = it%4;     // 0...3
    const int is = (8*ir)/16;// 0 or 1

    device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir;

    for (int ib = ix; ib < nb; ib += 4) {

        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+96]; sumy[3] += yl[i+24];
        }

        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales + 8*im + is;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }
            float dall = dh[0];
            float dmin = dh[1] * 1.f/16.f;
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));

            qs += step/2;
            sc += step;
            dh += step/2;
        }

        y4 += 4 * QK_K;
    }
#else
    const int ix = tiisg/2;  // 0...15
    const int it = tiisg%2;  // 0...1

    device const float * y4 = y + ix * QK_K + 8 * it;

    for (int ib = ix; ib < nb; ib += 16) {

        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
            yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8];
            yl[i+16] = y4[i+32]; sumy[2] += yl[i+16];
            yl[i+24] = y4[i+48]; sumy[3] += yl[i+24];
        }

        device const uint8_t  * sc = (device const uint8_t  *)x[ib].scales;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
        device const half     * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003);
                acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300);
                acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c);
                acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00);
                acc1[2] += yl[i+16] * (qs[i/2] & 0x0030);
                acc2[2] += yl[i+17] * (qs[i/2] & 0x3000);
                acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0);
                acc2[3] += yl[i+25] * (qs[i/2] & 0xc000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f +
                                 (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f +
                                 (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f +
                                 (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) -
                         dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4));

            qs += step/2;
            sc += step;
            dh += step/2;
        }

        y4 += 16 * QK_K;
    }
#endif

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}

#if QK_K == 256
kernel void kernel_mul_mat_q3_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t r2 = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

    float yl[16];

    const uint16_t kmask1 = 0x0303;
    const uint16_t kmask2 = 0x0f0f;

    const int tid = tiisg/2;
    const int ix  = tiisg%2;
    const int ip  = tid/8;        // 0 or 1
    const int il  = tid/2 - 4*ip; // 0...3
    const int ir  = tid%2;
    const int n   = 8;
    const int l0  = n*ir;

    const uint16_t m1 = 1 << (4*ip + il);
    const uint16_t m2 = m1 << 8;

    const int shift = 2*il;
    const uint16_t qm1 = 0x0003 << shift;
    const uint16_t qm2 = 0x0300 << shift;
    const int32_t v1 = 4 << shift;
    const int32_t v2 = 1024 << shift;

    const uint16_t s_shift1 = 4*ip;
    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
    const int ik = 4 + (il%2);

    const int q_offset = 32*ip + l0;
    const int y_offset = 128*ip + 32*il + l0;

    const int step = sizeof(block_q3_K) * nb / 2;

    device const float * y1 = yy + ix*QK_K + y_offset;

    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
    for (int i = ix; i < nb; i += 2) {

        for (int l = 0; l < 8; ++l) {
            yl[l+0] = y1[l+ 0];
            yl[l+8] = y1[l+16];
        }

        device const uint16_t * q  = (device const uint16_t *)(x[i].qs + q_offset);
        device const uint16_t * h  = (device const uint16_t *)(x[i].hmask + l0);
        device const uint16_t * a  = (device const uint16_t *)(x[i].scales);
        device const half     * dh = &x[i].d;

        for (int row = 0; row < 2; ++row) {

            const float d_all = (float)dh[0];
            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));

            float s1 = 0, s2 = 0;
            for (int l = 0; l < n; l += 2) {
                const uint16_t qs = q[l/2];
                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
            }

            float d = d_all * (s1 + 1.f/256.f * s2);
            sumf1[row] += d * scales[0];
            sumf2[row] += d;

            s1 = s2 = 0;
            for (int l = 0; l < n; l += 2) {
                const uint16_t qs = q[l/2+8];
                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
            }

            d = d_all * (s1 + 1.f/256.f * s2);
            sumf1[row] += d * scales[1];
            sumf2[row] += d;

            q  += step;
            h  += step;
            a  += step;
            dh += step;
        }

        y1 += 2 * QK_K;
    }

    for (int row = 0; row < 2; ++row) {
        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
        const float tot = simd_sum(sumf);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
        }
    }
}
#else
kernel void kernel_mul_mat_q3_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int64_t r2 = tgpig.z;

    const int row = 2 * r0 + sgitg;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

    const int ix = tiisg/4;
    const int il = 4 * (tiisg%4); // 0, 4, 8, 12
    const int im = il/8;          // 0, 0, 1, 1
    const int in = il%8;          // 0, 4, 0, 4

    float2 sum = {0.f, 0.f};

    for (int i = ix; i < nb; i += 8) {

        const float d_all = (float)(x[i].d);

        device const uint16_t * q = (device const uint16_t *)(x[i].qs + il);
        device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in);
        device const uint16_t * s = (device const uint16_t *)(x[i].scales);
        device const float    * y = yy + i * QK_K + il;

        const float d1 = d_all * ((int32_t)(s[0] & 0x000F) - 8);
        const float d2 = d_all * ((int32_t)(s[0] & 0x00F0) - 128) * 1.f/64.f;
        const float d3 = d_all * ((int32_t)(s[0] & 0x0F00) - 2048) * 1.f/4096.f;
        const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f;

        for (int l = 0; l < 4; l += 2) {
            const uint16_t hm = h[l/2] >> im;
            sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 :     4))
                    + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 :    16))
                    + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 :    64))
                    + y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 :   256));
            sum[1] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 :  1024))
                    + y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 :  4096))
                    + y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384))
                    + y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536));
        }
    }
    const float sumf = sum[0] + sum[1] * 1.f/256.f;

    const float tot = simd_sum(sumf);
    if (tiisg == 0) {
        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
    }
}
#endif

#if QK_K == 256
kernel void kernel_mul_mat_q4_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const uint16_t kmask1 = 0x3f3f;
    const uint16_t kmask2 = 0x0f0f;
    const uint16_t kmask3 = 0xc0c0;

    const int ix = tiisg/8;  // 0...3
    const int it = tiisg%8;  // 0...7
    const int im = it/4;     // 0 or 1
    const int ir = it%4;     // 0...3

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[16];
    float yh[16];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q4_K) * nb / 2;

    device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir;

    uint16_t sc16[4];
    thread const uint8_t * sc8 = (thread const uint8_t *)sc16;

    for (int ib = ix; ib < nb; ib += 4) {

        float4 sumy = {0.f, 0.f, 0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
            yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
            yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
            yh[i+8] = y4[i+160]; sumy[3] += yh[i+8];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im;
        device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir;
        device const half     * dh = &x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            sc16[0] = sc[0] & kmask1;
            sc16[1] = sc[2] & kmask1;
            sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
            sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);

            device const uint16_t * q2 = q1 + 32;

            float4 acc1 = {0.f, 0.f, 0.f, 0.f};
            float4 acc2 = {0.f, 0.f, 0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                 (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
                                 (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);

            q1 += step;
            sc += step;
            dh += step;
        }

        y4 += 4 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
#else
kernel void kernel_mul_mat_q4_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int ix = tiisg/4;  // 0...7
    const int it = tiisg%4;  // 0...3

    const int nb = ne00/QK_K;
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int r2 = tgpig.z;
    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
    const int ib_row = first_row * nb;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
    device const float      * y = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
    float yl[8];
    float yh[8];
    float sumf[N_DST]={0.f}, all_sum;

    const int step = sizeof(block_q4_K) * nb / 2;

    device const float * y4 = y + ix * QK_K + 8 * it;

    uint16_t sc16[4];

    for (int ib = ix; ib < nb; ib += 8) {

        float2 sumy = {0.f, 0.f};
        for (int i = 0; i < 8; ++i) {
            yl[i] = y4[i+ 0]; sumy[0] += yl[i];
            yh[i] = y4[i+32]; sumy[1] += yh[i];
        }

        device const uint16_t * sc = (device const uint16_t *)x[ib].scales;
        device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it;
        device const half     * dh = x[ib].d;

        for (int row = 0; row < N_DST; row++) {

            sc16[0] = sc[0] & 0x000f;
            sc16[1] = sc[0] & 0x0f00;
            sc16[2] = sc[0] & 0x00f0;
            sc16[3] = sc[0] & 0xf000;

            float2 acc1 = {0.f, 0.f};
            float2 acc2 = {0.f, 0.f};
            for (int i = 0; i < 8; i += 2) {
                acc1[0] += yl[i+0] * (qs[i/2] & 0x000F);
                acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00);
                acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0);
                acc2[1] += yh[i+1] * (qs[i/2] & 0xF000);
            }

            float dall = dh[0];
            float dmin = dh[1];
            sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] +
                                 (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) -
                         dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f);

            qs += step;
            sc += step;
            dh += step;
        }

        y4 += 8 * QK_K;
    }

    for (int row = 0; row < N_DST; ++row) {
        all_sum = simd_sum(sumf[row]);
        if (tiisg == 0) {
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
        }
    }
}
#endif

kernel void kernel_mul_mat_q5_K_f32(
        device const  void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int     r2 = tgpig.z;

    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

    float sumf[2]={0.f};

    const int step = sizeof(block_q5_K) * nb;

#if QK_K == 256
  1157. float yl[16], yh[16];
  1158. const uint16_t kmask1 = 0x3f3f;
  1159. const uint16_t kmask2 = 0x0f0f;
  1160. const uint16_t kmask3 = 0xc0c0;
  1161. const int tid = tiisg/4;
  1162. const int ix = tiisg%4;
  1163. const int im = tid/4;
  1164. const int ir = tid%4;
  1165. const int n = 8;
  1166. const int l0 = n*ir;
  1167. const int q_offset = 32*im + l0;
  1168. const int y_offset = 64*im + l0;
  1169. const uint8_t hm1 = 1u << (2*im);
  1170. const uint8_t hm2 = hm1 << 1;
  1171. const uint8_t hm3 = hm1 << 4;
  1172. const uint8_t hm4 = hm2 << 4;
  1173. uint16_t sc16[4];
  1174. thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
  1175. device const float * y1 = yy + ix*QK_K + y_offset;
  1176. for (int i = ix; i < nb; i += 4) {
  1177. device const uint8_t * q1 = x[i].qs + q_offset;
  1178. device const uint8_t * qh = x[i].qh + l0;
  1179. device const half * dh = &x[i].d;
  1180. device const uint16_t * a = (device const uint16_t *)x[i].scales + im;
  1181. device const float * y2 = y1 + 128;
  1182. float4 sumy = {0.f, 0.f, 0.f, 0.f};
  1183. for (int l = 0; l < 8; ++l) {
  1184. yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
  1185. yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
  1186. yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
  1187. yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
  1188. }
  1189. for (int row = 0; row < 2; ++row) {
  1190. device const uint8_t * q2 = q1 + 64;
  1191. sc16[0] = a[0] & kmask1;
  1192. sc16[1] = a[2] & kmask1;
  1193. sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
  1194. sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
  1195. float4 acc = {0.f, 0.f, 0.f, 0.f};
  1196. for (int l = 0; l < n; ++l) {
  1197. uint8_t h = qh[l];
  1198. acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
  1199. acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
  1200. acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
  1201. acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
  1202. }
  1203. const float dall = dh[0];
  1204. const float dmin = dh[1];
  1205. sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
  1206. dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
  1207. q1 += step;
  1208. qh += step;
  1209. dh += step/2;
  1210. a += step/2;
  1211. }
  1212. y1 += 4 * QK_K;
  1213. }
  1214. #else
  1215. float yl[8], yh[8];
  1216. const int il = 4 * (tiisg/8); // 0, 4, 8, 12
  1217. const int ix = tiisg%8;
  1218. const int im = il/8; // 0, 0, 1, 1
  1219. const int in = il%8; // 0, 4, 0, 4
  1220. device const float * y = yy + ix*QK_K + il;
  1221. for (int i = ix; i < nb; i += 8) {
  1222. for (int l = 0; l < 4; ++l) {
  1223. yl[l+0] = y[l+ 0];
  1224. yl[l+4] = y[l+16];
  1225. yh[l+0] = y[l+32];
  1226. yh[l+4] = y[l+48];
  1227. }
  1228. device const half * dh = &x[i].d;
  1229. device const uint8_t * q = x[i].qs + il;
  1230. device const uint8_t * h = x[i].qh + in;
  1231. device const int8_t * s = x[i].scales;
  1232. for (int row = 0; row < 2; ++row) {
  1233. const float d = dh[0];
  1234. float2 acc = {0.f, 0.f};
  1235. for (int l = 0; l < 4; ++l) {
  1236. const uint8_t hl = h[l] >> im;
  1237. acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16))
  1238. + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16));
  1239. acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256))
  1240. + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 0 : 256));
  1241. }
  1242. sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]);
  1243. q += step;
  1244. h += step;
  1245. s += step;
  1246. dh += step/2;
  1247. }
  1248. y += 8 * QK_K;
  1249. }
  1250. #endif
  1251. for (int row = 0; row < 2; ++row) {
  1252. const float tot = simd_sum(sumf[row]);
  1253. if (tiisg == 0) {
  1254. dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
  1255. }
  1256. }
  1257. }
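
// kernel_mul_mat_q6_K_f32: mat-vec multiply for 6-bit k-quants. Each block_q6_K
// stores QK_K weights as 4 low bits in ql[] plus 2 high bits in qh[], one int8_t
// scale per 16 weights, and a single half super-scale d. The kernel below splices
// each 6-bit value back together, removes its +32 bias, and accumulates against
// src1. Illustrative reconstruction of a single weight (a sketch for reference,
// not called anywhere in this file):
//
//     inline float q6_K_weight(uint8_t lo4, uint8_t hi2, int8_t sc, float d) {
//         return d * sc * (float)((int8_t)((lo4 & 0xF) | (hi2 << 4)) - 32);
//     }
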
kernel void kernel_mul_mat_q6_K_f32(
        device const void * src0,
        device const float * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01[[buffer(4)]],
        constant   int64_t & ne02[[buffer(5)]],
        constant   int64_t & ne10[[buffer(9)]],
        constant   int64_t & ne12[[buffer(11)]],
        constant   int64_t & ne0[[buffer(15)]],
        constant   int64_t & ne1[[buffer(16)]],
        constant   uint    & gqa[[buffer(17)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]]) {

    const uint8_t kmask1 = 0x03;
    const uint8_t kmask2 = 0x0C;
    const uint8_t kmask3 = 0x30;
    const uint8_t kmask4 = 0xC0;

    const int nb = ne00/QK_K;

    const int64_t r0 = tgpig.x;
    const int64_t r1 = tgpig.y;
    const int     r2 = tgpig.z;

    const int row = 2 * r0 + sgitg;
    const uint offset0 = r2/gqa*(nb*ne0);
    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;

    float sumf = 0;

#if QK_K == 256
    const int tid = tiisg/2;
    const int ix  = tiisg%2;
    const int ip  = tid/8;   // 0 or 1
    const int il  = tid%8;
    const int n   = 4;
    const int l0  = n*il;
    const int is  = 8*ip + l0/16;

    const int y_offset   = 128*ip + l0;
    const int q_offset_l =  64*ip + l0;
    const int q_offset_h =  32*ip + l0;

    for (int i = ix; i < nb; i += 2) {
        device const uint8_t * q1 = x[i].ql + q_offset_l;
        device const uint8_t * q2 = q1 + 32;
        device const uint8_t * qh = x[i].qh + q_offset_h;
        device const int8_t  * sc = x[i].scales + is;

        device const float * y = yy + i * QK_K + y_offset;

        const float dall = x[i].d;

        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < n; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }

        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
    }
#else
    const int ix = tiisg/4;
    const int il = 4*(tiisg%4);

    for (int i = ix; i < nb; i += 8) {
        device const float   * y  = yy + i * QK_K + il;
        device const uint8_t * ql = x[i].ql + il;
        device const uint8_t * qh = x[i].qh + il;
        device const int8_t  * s  = x[i].scales;

        const float d = x[i].d;

        float4 sums = {0.f, 0.f, 0.f, 0.f};
        for (int l = 0; l < 4; ++l) {
            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
            sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
            sums[2] += y[l+32] * ((int8_t)((ql[l+ 0]  >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
            sums[3] += y[l+48] * ((int8_t)((ql[l+16]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
        }
        sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
    }
#endif

    const float tot = simd_sum(sumf);
    if (tiisg == 0) {
        dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
    }
}

//============================= templates and their specializations =============================
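
// Shared contract for the dequantize_* helpers below: each call converts one
// 16-weight slice of a quant block into the thread-local 4x4 matrix `reg`, and
// the short index `il` selects which slice. A block_q carries 16*nl weights, so
// callers pass block index = ind/nl and il = ind%nl (see kernel_get_rows and
// kernel_mul_mm at the end of this file, which take the helpers as the
// dequantize_func template parameter).
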
template <typename type4x4>
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
    half4x4 temp = *(((device half4x4 *)src));
    for (int i = 0; i < 16; i++) {
        reg[i/4][i%4] = temp[i/4][i%4];
    }
}

template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
    const half d = il ? (xb->d / 16.h) : xb->d;
    const half m = il ? (-8.h * 16.h)  : -8.h;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = il ? 0xF000 : 0x0F00;

    for (int i = 0; i < 8; i++) {
        reg[i/2][2*(i%2)  ] = (((qs[i] & mask0)     ) + m) * d;
        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
    }
}
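
// Note on the nibble trick used above (and in dequantize_q4_1 below): on the
// high-nibble pass (il != 0) the quant is left in place as (q & 0xF0) == 16*q_hi
// and the scale is divided by 16 instead of shifting q right, saving a shift per
// weight. Worked example with hypothetical values d = 0.5h, q_lo = 3, q_hi = 10:
//
//     il == 0:  (0x03 + (-8.h))      * 0.5h        == -2.5   // (3  - 8) * d
//     il == 1:  (0xA0 + (-8.h*16.h)) * (0.5h/16.h) ==  1.0   // (10 - 8) * d
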
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
    const half d = il ? (xb->d / 16.h) : xb->d;
    const half m = xb->m;
    const ushort mask0 = il ? 0x00F0 : 0x000F;
    const ushort mask1 = il ? 0xF000 : 0x0F00;

    for (int i = 0; i < 8; i++) {
        reg[i/2][2*(i%2)  ] = (((qs[i] & mask0)     ) * d) + m;
        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
    }
}

template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
    const half d = xb->d;
    const half min = xb->dmin;
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    half dl, ml;
    uint8_t sc = xb->scales[il];

#if QK_K == 256
    q = q + 32*(il/8) + 16*(il&1);
    il = (il/2)%4;
#endif
    half  coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uchar mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}
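
// The coef/mask pair above extracts one of the four 2-bit "crumbs" in each byte
// without a shift: mask 3/12/48/192 isolates bit pair (2*il, 2*il+1) and coef
// 1, 1/4.h, 1/16.h, 1/64.h rescales the masked value back to 0..3. E.g. for
// il == 2, (q & 48) equals 16 * crumb, and multiplying by 1/16.h recovers it.
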
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
    const float d_all = (float)(xb->d);
    device const uint8_t * q = (device const uint8_t *)xb->qs;
    device const uint8_t * h = (device const uint8_t *)xb->hmask;
    device const int8_t * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    q = q + 32 * (il/8) + 16 * (il&1);
    h = h + 16 * (il&1);
    uint8_t m = 1 << (il/2);
    uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
                                 ((il/4)>0 ? 12  : 3);
    uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
    uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
    int16_t  dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
                                 (scale_2&kmask2) | ((scale_1&kmask1) << 4);
    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);

    il = (il/2)%4;
    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);

    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
    }
#else
    float    kcoef = il&1 ? 1.f/16.f : 1.f;
    uint16_t kmask = il&1 ? 0xF0     : 0x0F;
    float    dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
    float   coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
    uint8_t mask = il>1 ? (il>2 ? 192    : 48)     : (il>0 ? 12    : 3);
    uint8_t m = 1<<(il*2);
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
    }
#endif
}
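
// q3_K packs 16 signed 6-bit scales (biased by 32, hence the `- 32.f` above)
// into its 12-byte scales field; scale_2/scale_1 together with the kmask pair
// splice the low and high bits of the 6-bit scale back together for the
// selected slice.
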
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q = xb->qs;

#if QK_K == 256
    const float d = (float)(xb->d);
    const float min = (float)(xb->dmin);
    short is = (il/4) * 2;
    q = q + (il/4) * 32 + 16 * (il&1);
    il = il%4;
    const uchar4 sc = get_scale_min_k4(is, xb->scales);
    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
    const float ml = il<2 ? min * sc[1] : min * sc[3];
#else
    q = q + 16 * (il&1);
    device const uint8_t * s = xb->scales;
    device const half2 * dh = (device const half2 *)xb->d;
    const float2 d = (float2)dh[0];
    const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
    const float ml = il<2 ? d[1] * (s[0]>>4)  : d[1] * (s[1]>>4);
#endif
    const ushort mask = il<2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * (q[i] & mask) - ml;
    }
}
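
// get_scale_min_k4 (defined earlier in this file) unpacks the packed 6-bit
// scale/min pairs of the 256-weight k-quant blocks; dl/ml above then apply the
// usual affine form w = dl * q - ml, reusing the /16.h high-nibble trick from
// dequantize_q4_0.
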
template <typename type4x4>
void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
    device const uint8_t * q  = xb->qs;
    device const uint8_t * qh = xb->qh;

#if QK_K == 256
    const float d = (float)(xb->d);
    const float min = (float)(xb->dmin);
    short is = (il/4) * 2;
    q  = q  + 32 * (il/4) + 16 * (il&1);
    qh = qh + 16 * (il&1);
    uint8_t ul = 1 << (il/2);
    il = il%4;
    const uchar4 sc = get_scale_min_k4(is, xb->scales);
    const float dl = il<2 ? d * sc[0]   : d * sc[2]/16.h;
    const float ml = il<2 ? min * sc[1] : min * sc[3];

    const ushort mask  = il<2 ? 0x0F : 0xF0;
    const float qh_val = il<2 ? 16.f : 256.f;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
    }
#else
    q = q + 16 * (il&1);
    device const int8_t * s = xb->scales;
    const float dl = xb->d * s[il];
    uint8_t m = 1<<(il*2);
    const float  coef = il<2 ? 1.f  : 1.f/16.f;
    const ushort mask = il<2 ? 0x0F : 0xF0;
    for (int i = 0; i < 16; ++i) {
        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
    }
#endif
}
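
// For q5_K the fifth bit comes from qh: when the `ul` bit is set it contributes
// 16 on the low-nibble pass or 256 on the high-nibble pass (later rescaled by
// 1/16) before the affine scale, mirroring the hm1..hm4 masks used in
// kernel_mul_mat_q5_K_f32 above.
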
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
    const float d_all = (float)(xb->d);
    device const uint8_t * ql = (device const uint8_t *)xb->ql;
    device const uint8_t * qh = (device const uint8_t *)xb->qh;
    device const int8_t * scales = (device const int8_t *)xb->scales;

#if QK_K == 256
    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
    qh = qh + 32*(il/8) + 16*(il&1);
    float sc = scales[(il%2) + 2 * ((il/2))];
    il = (il/2)%4;
#else
    ql = ql + 16 * (il&1);
    float sc = scales[il];
#endif
    for (int i = 0; i < 16; ++i) {
        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
        uint16_t kmask2 = il>1 ? 0xF0              : 0x0F;
        const float coef = il>1 ? 1.f/16.f : 1.f;
        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
                         ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
        reg[i/4][i%4] = d_all * sc * q * coef;
    }
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
        device const void * src0,
        device const int * src1,
        device float * dst,
        constant int64_t & ne00,
        constant uint64_t & nb01,
        constant uint64_t & nb1,
        uint tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint tptg[[threads_per_threadgroup]]) {
    const int i = tgpig;
    const int r = ((device int32_t *) src1)[i];

    for (int ind = tiitg; ind < ne00/16; ind += tptg) {
        float4x4 temp;
        dequantize_func(
            ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
        *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
    }
}
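
// kernel_get_rows: one threadgroup per looked-up row. src1 holds the row ids;
// the threads of the group stride over the selected row in 16-weight chunks,
// each chunk dequantized straight into dst. Equivalent scalar pseudocode (a
// sketch, with a hypothetical dequant_16 helper following the contract above):
//
//     for (int ind = tiitg; ind < ne00/16; ind += tptg)
//         dst_row[ind] = dequant_16(row_base(src0, r) + ind/nl, /*il=*/ind%nl);
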
#define BLOCK_SIZE_M 64  // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32  // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
#define THREAD_MAT_M 4   // each thread takes 4 simdgroup matrices from matrix A
#define THREAD_MAT_N 2   // each thread takes 2 simdgroup matrices from matrix B
#define THREAD_PER_BLOCK 128
#define THREAD_PER_ROW 2 // 2 threads for each row in matrix A to load numbers
#define THREAD_PER_COL 4 // 4 threads for each row in matrix B to load numbers
#define SG_MAT_SIZE 64   // simdgroup matrix is of shape 8x8
#define SG_MAT_ROW 8
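
// Tile decomposition for kernel_mul_mm below: each threadgroup (128 threads =
// 4 simdgroups) produces a 64x32 block of dst. Every simdgroup owns 8 of the
// 8x8 accumulators (c_res), i.e. a 32x16 sub-tile: sgitg&1 selects the M half
// and sgitg>>1 the N half, which is where the write-back offsets 32*(sgitg&1)
// and 16*(sgitg>>1) come from.
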
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm(device const uchar * src0,
                          device const float * src1,
                          device float * dst,
                          constant int64_t & ne00,
                          constant int64_t & ne02,
                          constant int64_t & nb01,
                          constant int64_t & nb02,
                          constant int64_t & ne12,
                          constant int64_t & ne0,
                          constant int64_t & ne1,
                          constant uint & gqa,
                          threadgroup uchar * shared_memory [[threadgroup(0)]],
                          uint3 tgpig[[threadgroup_position_in_grid]],
                          uint tiitg[[thread_index_in_threadgroup]],
                          uint sgitg[[simdgroup_index_in_threadgroup]]) {
    threadgroup half  * sa = ((threadgroup half  *)shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
    const uint r1 = tgpig.x;
    const uint im = tgpig.z;

    // if this block is of 64x32 shape or smaller
    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;

    // a thread shouldn't load data outside of the matrix
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

    simdgroup_half8x8  ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
    }

    short il = (tiitg % THREAD_PER_ROW);
    uint   offset0 = im/gqa*nb02;
    ushort offset1 = il/nl;
    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
    device const float  * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
                            + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;

    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
        half4x4 temp_a;
        dequantize_func(x, il, temp_a);
        #pragma unroll(16)
        for (int i = 0; i < 16; i++) {
            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
              + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
              + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
        }
        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
            = *((device float2x4 *)y);

        il = (il + 2 < nl) ? il + 2 : il % 2;
        x  = (il < 2) ? x + (2+nl-1)/nl : x;
        y += BLOCK_SIZE_K;

        threadgroup_barrier(mem_flags::mem_threadgroup);
        // load matrices from threadgroup memory and conduct outer products
        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
        #pragma unroll(4)
        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
            #pragma unroll(4)
            for (int i = 0; i < 4; i++) {
                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
            }
            simdgroup_barrier(mem_flags::mem_none);
            #pragma unroll(2)
            for (int i = 0; i < 2; i++) {
                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
            }
            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
            #pragma unroll(8)
            for (int i = 0; i < 8; i++){
                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
            }
        }
    }

    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
        device float * C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
                         + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
        }
    } else {
        // block is smaller than 64x32, we should avoid writing data outside of the matrix
        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                     + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
        for (int i = 0; i < 8; i++) {
            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
        device float * C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
        if (sgitg == 0) {
            for (int i = 0; i < n_rows; i++) {
                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                }
            }
        }
    }
}
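
// QK_NL is the nl template argument used for the k-quant instantiations below:
// the number of 16-weight slices per block, i.e. QK_K/16 (16 slices when
// QK_K == 256, 4 for the QK_K == 64 build).
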
#if QK_K == 256
#define QK_NL 16
#else
#define QK_NL 4
#endif
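
// Each explicit instantiation below is exported under its [[host_name]]; that
// string is what the host side is expected to pass to MTLLibrary's
// newFunctionWithName: when creating the compute pipelines (an assumption
// about the host code, which lives outside this file).
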
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                          constant uint64_t &, constant uint64_t &, uint, uint, uint);

template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;

typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &, \
                        constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
                        constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);

template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;