# ggml_vk_generate_shaders.py — generates SPIR-V compute shaders (via glslc)
# for the ggml Vulkan backend from the GLSL template strings below.
  1. #!/usr/bin/env python
  2. import argparse
  3. import asyncio
  4. import os
  5. import sys
  6. from tempfile import gettempdir, NamedTemporaryFile
  7. shader_f32 = """
  8. #define FLOAT_TYPE float
  9. """
  10. shader_f16 = """
  11. #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  12. #define FLOAT_TYPE float16_t
  13. """
  14. shader_int8_ext = """
  15. #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
  16. """
  17. # Type-specific defines
  18. shader_f16_defines = """
  19. #define QUANT_K 1
  20. #define QUANT_R 1
  21. #define A_TYPE float16_t
  22. """
  23. shader_q4_0_defines = """
  24. #define QUANT_K 32
  25. #define QUANT_R 2
  26. struct block_q4_0
  27. {
  28. float16_t d;
  29. uint8_t qs[16];
  30. };
  31. #define A_TYPE block_q4_0
  32. """
  33. shader_q4_1_defines = """
  34. #define QUANT_K 32
  35. #define QUANT_R 2
  36. struct block_q4_1
  37. {
  38. float16_t d;
  39. float16_t m;
  40. uint8_t qs[16];
  41. };
  42. #define A_TYPE block_q4_1
  43. """
  44. shader_q5_0_defines = """
  45. #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  46. #define QUANT_K 32
  47. #define QUANT_R 2
  48. struct block_q5_0
  49. {
  50. float16_t d;
  51. uint16_t qh[2];
  52. uint8_t qs[16];
  53. };
  54. #define A_TYPE block_q5_0
  55. """
  56. shader_q5_1_defines = """
  57. #define QUANT_K 32
  58. #define QUANT_R 2
  59. struct block_q5_1
  60. {
  61. float16_t d;
  62. float16_t m;
  63. uint qh;
  64. uint8_t qs[16];
  65. };
  66. #define A_TYPE block_q5_1
  67. """
  68. shader_q8_0_defines = """
  69. #define QUANT_K 32
  70. #define QUANT_R 1
  71. struct block_q8_0
  72. {
  73. float16_t d;
  74. int8_t qs[32];
  75. };
  76. #define A_TYPE block_q8_0
  77. """
  78. # K-quants
  79. shader_q2_K_defines = """
  80. #define QUANT_K 256
  81. struct block_q2_K
  82. {
  83. uint8_t scales[QUANT_K/16];
  84. uint8_t qs[QUANT_K/4];
  85. f16vec2 d;
  86. };
  87. #define A_TYPE block_q2_K
  88. """
  89. shader_q3_K_defines = """
  90. #define QUANT_K 256
  91. struct block_q3_K
  92. {
  93. uint8_t hmask[QUANT_K/8];
  94. uint8_t qs[QUANT_K/4];
  95. uint8_t scales[12];
  96. float16_t d;
  97. };
  98. #define A_TYPE block_q3_K
  99. """
  100. shader_q4_K_defines = """
  101. #define QUANT_K 256
  102. struct block_q4_K
  103. {
  104. f16vec2 d;
  105. uint8_t scales[3*QUANT_K/64];
  106. uint8_t qs[QUANT_K/2];
  107. };
  108. #define A_TYPE block_q4_K
  109. """
  110. shader_q5_K_defines = """
  111. #define QUANT_K 256
  112. struct block_q5_K
  113. {
  114. f16vec2 d;
  115. uint8_t scales[12];
  116. uint8_t qh[QUANT_K/8];
  117. uint8_t qs[QUANT_K/2];
  118. };
  119. #define A_TYPE block_q5_K
  120. """
  121. shader_q6_K_defines = """
  122. #define QUANT_K 256
  123. struct block_q6_K
  124. {
  125. uint8_t ql[QUANT_K/2];
  126. uint8_t qh[QUANT_K/4];
  127. int8_t scales[QUANT_K/16];
  128. float16_t d;
  129. };
  130. #define A_TYPE block_q6_K
  131. """
  132. # Dequant functions
  133. shader_f16_dequant_func = """
  134. #define DEQUANT_FUNC vec2 v = vec2(data_a[ib + 0], data_a[ib + 1]);
  135. """
  136. shader_q4_0_dequant_func = """
  137. #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
  138. const uint vui = uint(data_a[ib].qs[iqs]); \
  139. vec2 v = vec2(vui & 0xF, vui >> 4); \
  140. v = (v - 8.0f)*d;
  141. """
  142. shader_q4_1_dequant_func = """
  143. #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
  144. const float m = float(data_a[ib].m); \
  145. const uint vui = uint(data_a[ib].qs[iqs]); \
  146. vec2 v = vec2(vui & 0xF, vui >> 4); \
  147. v = v*d + m;
  148. """
  149. shader_q5_0_dequant_func = """
  150. #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
  151. const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; \
  152. const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
  153. const uint vui = uint(data_a[ib].qs[iqs]); \
  154. vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
  155. v = (v - 16.0f) * d;
  156. """
  157. shader_q5_1_dequant_func = """
  158. #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
  159. const float m = float(data_a[ib].m); \
  160. const ivec2 qh = ivec2(((data_a[ib].qh >> iqs) << 4) & 0x10, (data_a[ib].qh >> (iqs + 12)) & 0x10); \
  161. const uint vui = uint(data_a[ib].qs[iqs]); \
  162. vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
  163. v = v*d + m;
  164. """
  165. shader_q8_0_dequant_func = """
  166. #define DEQUANT_FUNC const float d = float(data_a[ib].d); \
  167. vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])); \
  168. v = v * d;
  169. """
  170. # MULMAT
  171. mulmat_head = """#version 450
  172. #extension GL_EXT_control_flow_attributes : enable
  173. #extension GL_EXT_shader_16bit_storage : require
  174. #ifndef LOAD_VEC
  175. #define LOAD_VEC 1
  176. #endif
  177. """
  178. mulmat_body = """
  179. layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
  180. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  181. layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
  182. layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
  183. layout (push_constant) uniform parameter
  184. {
  185. uint M;
  186. uint N;
  187. uint K;
  188. uint stride_a;
  189. uint stride_b;
  190. uint stride_d;
  191. uint k_split;
  192. uint ne02;
  193. uint ne12;
  194. uint broadcast2;
  195. uint broadcast3;
  196. uint batch_stride_a;
  197. uint batch_stride_b;
  198. uint batch_stride_d;
  199. } p;
  200. layout (constant_id = 1) const uint BM = 64;
  201. layout (constant_id = 2) const uint BN = 64;
  202. layout (constant_id = 3) const uint BK = 16;
  203. layout (constant_id = 4) const uint WM = 32;
  204. layout (constant_id = 5) const uint WN = 32;
  205. layout (constant_id = 6) const uint WMITER = 2;
  206. layout (constant_id = 7) const uint TM = 4;
  207. layout (constant_id = 8) const uint TN = 2;
  208. layout (constant_id = 9) const uint WARP = 32;
  209. shared FLOAT_TYPE buf_a[BM * (BK+1)];
  210. shared FLOAT_TYPE buf_b[BN * (BK+1)];
  211. void main() {
  212. const uint i13 = gl_GlobalInvocationID.z / p.ne12;
  213. const uint i12 = gl_GlobalInvocationID.z % p.ne12;
  214. const uint i03 = i13 / p.broadcast3;
  215. const uint i02 = i12 / p.broadcast2;
  216. const uint batch_idx_a = i03 * p.ne02 + i02;
  217. const uint blocks_m = (p.M + BM - 1) / BM;
  218. const uint ir = gl_WorkGroupID.x % blocks_m;
  219. const uint ik = gl_WorkGroupID.x / blocks_m;
  220. const uint ic = gl_WorkGroupID.y;
  221. const uint warp_i = gl_LocalInvocationID.x / WARP;
  222. const uint warp_r = warp_i % (BM / WM);
  223. const uint warp_c = warp_i / (BM / WM);
  224. const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
  225. const uint WSUBM = WM / WMITER;
  226. const uint WSUBN = WN / WNITER;
  227. const uint tiw = gl_LocalInvocationID.x % WARP;
  228. const uint tiwr = tiw % (WSUBM / TM);
  229. const uint tiwc = tiw / (WSUBM / TM);
  230. const uint loadr = gl_LocalInvocationID.x % (BK / LOAD_VEC);
  231. const uint loadc = gl_LocalInvocationID.x / (BK / LOAD_VEC);
  232. const uint loadstride = gl_WorkGroupSize.x * LOAD_VEC / BK;
  233. const uint start_k = ik * p.k_split;
  234. const uint end_k = min(p.K, (ik + 1) * p.k_split);
  235. uint pos_a = (batch_idx_a * p.batch_stride_a + ir * BM * p.stride_a + start_k) / LOAD_VEC;
  236. uint pos_b = (gl_GlobalInvocationID.z * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC;
  237. float sums[WMITER * TM * WNITER * TN];
  238. FLOAT_TYPE cache_a[WMITER * TM];
  239. FLOAT_TYPE cache_b[WNITER * TN];
  240. [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
  241. sums[i] = 0.0f;
  242. }
  243. [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
  244. [[unroll]] for (uint l = 0; l < BM; l += loadstride) {
  245. #if LOAD_VEC == 8
  246. const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
  247. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
  248. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
  249. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
  250. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w);
  251. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x);
  252. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y);
  253. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
  254. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
  255. #elif LOAD_VEC == 4
  256. const uint idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
  257. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
  258. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
  259. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
  260. buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
  261. #else
  262. if (ir * BM + loadc + l < p.M && block + loadr < end_k) {
  263. buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
  264. } else {
  265. buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
  266. }
  267. #endif
  268. }
  269. [[unroll]] for (uint l = 0; l < BN; l += loadstride) {
  270. #if LOAD_VEC == 8
  271. const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
  272. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
  273. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
  274. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
  275. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w);
  276. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x);
  277. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y);
  278. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
  279. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
  280. #elif LOAD_VEC == 4
  281. const uint idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
  282. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
  283. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
  284. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
  285. buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
  286. #else
  287. if (ic * BN + loadc + l < p.N && block + loadr < end_k) {
  288. buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
  289. } else {
  290. buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f);
  291. }
  292. #endif
  293. }
  294. barrier();
  295. pos_a += BK / LOAD_VEC;
  296. pos_b += BK / LOAD_VEC;
  297. for (uint i = 0; i < BK; i++) {
  298. // Load from shared into cache
  299. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  300. [[unroll]] for (uint j = 0; j < TM; j++) {
  301. cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
  302. }
  303. }
  304. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  305. [[unroll]] for (uint j = 0; j < TN; j++) {
  306. cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
  307. }
  308. }
  309. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  310. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  311. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  312. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  313. sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
  314. }
  315. }
  316. }
  317. }
  318. }
  319. barrier();
  320. }
  321. const uint dr = ir * BM + warp_r * WM;
  322. const uint dc = ic * BN + warp_c * WN;
  323. const uint offsets = gl_GlobalInvocationID.z * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
  324. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  325. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  326. const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
  327. const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
  328. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  329. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  330. if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
  331. data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
  332. }
  333. }
  334. }
  335. }
  336. }
  337. }
  338. """
  339. mulmat_split_k_reduce_src = """#version 450
  340. #extension GL_EXT_control_flow_attributes : enable
  341. layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
  342. layout (binding = 0) readonly buffer A {float data_a[];};
  343. layout (binding = 1) writeonly buffer D {float data_d[];};
  344. layout (push_constant) uniform parameter {
  345. uint ne;
  346. uint k_num;
  347. } p;
  348. void main() {
  349. const uint idx = gl_GlobalInvocationID.x;
  350. if (idx >= p.ne) {
  351. return;
  352. }
  353. float result = 0.0f;
  354. [[unroll]] for (uint i = 0; i < p.k_num; i++) {
  355. result += data_a[i * p.ne + idx];
  356. }
  357. data_d[idx] = result;
  358. }
  359. """
  360. # DEQUANT SHADER
  361. dequant_head = """#version 450
  362. #extension GL_EXT_control_flow_attributes : require
  363. #extension GL_EXT_shader_16bit_storage : require
  364. """
  365. dequant_body = """
  366. layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
  367. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  368. layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
  369. layout (push_constant) uniform parameter
  370. {
  371. int M;
  372. int K;
  373. int stride_a;
  374. int stride_b;
  375. } p;
  376. void main() {
  377. const int i = int(gl_GlobalInvocationID.x);
  378. // Transposed
  379. const int row = i % (p.K / QUANT_K);
  380. const int col = i / (p.K / QUANT_K);
  381. if (row * QUANT_K >= p.K || col >= p.M) {
  382. return;
  383. }
  384. const int stride_a = p.stride_a / QUANT_K;
  385. const int ib = col * stride_a + row;
  386. const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
  387. const int step = QUANT_R == 1 ? 2 : 1;
  388. [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
  389. DEQUANT_FUNC
  390. data_b[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
  391. data_b[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y);
  392. }
  393. }
  394. """
  395. # K-quants
  396. dequant_q2_K_body = """
  397. layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
  398. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  399. layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
  400. layout (push_constant) uniform parameter
  401. {
  402. int M;
  403. int K;
  404. int stride_a;
  405. int stride_b;
  406. } p;
  407. void main() {
  408. [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
  409. const int i = int(gl_WorkGroupID.x * 256 + wgy);
  410. if (i >= p.M * p.K / QUANT_K) {
  411. return;
  412. }
  413. const int tid = int(gl_LocalInvocationID.x);
  414. const int ip = tid / 32;
  415. const int il = tid - 32 * ip;
  416. const int is = 8 * ip + il / 16;
  417. const int y_idx = i * QUANT_K + 128 * ip + il;
  418. const int ql_idx = 32 * ip + il;
  419. const uint8_t qs = data_a[i].qs[32 * ip + il];
  420. FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
  421. FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
  422. data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4));
  423. data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4));
  424. data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4));
  425. data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4));
  426. }
  427. }
  428. """
  429. dequant_q3_K_body = """
  430. layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
  431. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  432. layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
  433. layout (push_constant) uniform parameter
  434. {
  435. int M;
  436. int K;
  437. int stride_a;
  438. int stride_b;
  439. } p;
  440. void main() {
  441. [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
  442. const int i = int(gl_WorkGroupID.x * 256 + wgy);
  443. if (i >= p.M * p.K / QUANT_K) {
  444. return;
  445. }
  446. const int r = int(gl_LocalInvocationID.x) / 4;
  447. const int tid = r / 2;
  448. const int is0 = r % 2;
  449. const int l0 = 16 * is0 + 4 * (int(gl_LocalInvocationID.x) % 4);
  450. const int n = tid / 4;
  451. const int j = tid - 4*n;
  452. const uint8_t m = uint8_t(1 << (4*n + j));
  453. const int is = 8*n + 2*j + is0;
  454. const int shift = 2*j;
  455. const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) :
  456. is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) :
  457. is < 12 ? (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) :
  458. (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4));
  459. const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d);
  460. const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32);
  461. const int y_idx = i * QUANT_K + 128 * n + 32 * j;
  462. const int qs_idx = 32*n;
  463. for (int l = l0; l < l0 + 4; ++l) {
  464. data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4)));
  465. }
  466. }
  467. }
  468. """
  469. dequant_q4_K_body = """
  470. layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
  471. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  472. layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
  473. layout (push_constant) uniform parameter
  474. {
  475. int M;
  476. int K;
  477. int stride_a;
  478. int stride_b;
  479. } p;
  480. void main() {
  481. [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
  482. const int i = int(gl_WorkGroupID.x * 256 + wgy);
  483. if (i >= p.M * p.K / QUANT_K) {
  484. return;
  485. }
  486. const int tid = int(gl_LocalInvocationID.x);
  487. const int il = tid / 8;
  488. const int ir = tid % 8;
  489. const int is = 2 * il;
  490. const int n = 4;
  491. const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
  492. const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
  493. const int y_idx = i * QUANT_K + 64 * il + n * ir;
  494. const int qs_idx = 32*il + n * ir;
  495. uint8_t sc;
  496. uint8_t m;
  497. if (is < 4) {
  498. sc = uint8_t(data_a[i].scales[is] & 63);
  499. m = uint8_t(data_a[i].scales[is + 4] & 63);
  500. } else {
  501. sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
  502. m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
  503. }
  504. const FLOAT_TYPE d1 = dall * sc;
  505. const FLOAT_TYPE m1 = dmin * m;
  506. if (is < 4) {
  507. sc = uint8_t(data_a[i].scales[is + 1] & 63);
  508. m = uint8_t(data_a[i].scales[is + 5] & 63);
  509. } else {
  510. sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
  511. m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
  512. }
  513. const FLOAT_TYPE d2 = dall * sc;
  514. const FLOAT_TYPE m2 = dmin * m;
  515. [[unroll]] for (int l = 0; l < n; ++l) {
  516. data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] & 0xF) - m1);
  517. data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[i].qs[qs_idx + l] >> 4) - m2);
  518. }
  519. }
  520. }
  521. """
  522. dequant_q5_K_body = """
  523. layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
  524. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  525. layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
  526. layout (push_constant) uniform parameter
  527. {
  528. int M;
  529. int K;
  530. int stride_a;
  531. int stride_b;
  532. } p;
  533. void main() {
  534. [[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
  535. const int i = int(gl_WorkGroupID.x * 256 + wgy);
  536. if (i >= p.M * p.K / QUANT_K) {
  537. return;
  538. }
  539. const int tid = int(gl_LocalInvocationID.x);
  540. const int il = tid / 16;
  541. const int ir = tid % 16;
  542. const int is = 2 * il;
  543. const FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x);
  544. const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y);
  545. const int y_idx = i * QUANT_K + 64 * il + 2 * ir;
  546. const int qs_idx = 32*il + 2 * ir;
  547. const int qh_idx = 2 * ir;
  548. uint8_t sc;
  549. uint8_t m;
  550. if (is < 4) {
  551. sc = uint8_t(data_a[i].scales[is] & 63);
  552. m = uint8_t(data_a[i].scales[is + 4] & 63);
  553. } else {
  554. sc = uint8_t((data_a[i].scales[is + 4] & 0xF) | ((data_a[i].scales[is - 4] >> 6) << 4));
  555. m = uint8_t((data_a[i].scales[is + 4] >> 4) | ((data_a[i].scales[is ] >> 6) << 4));
  556. }
  557. const FLOAT_TYPE d1 = dall * sc;
  558. const FLOAT_TYPE m1 = dmin * m;
  559. if (is < 4) {
  560. sc = uint8_t(data_a[i].scales[is + 1] & 63);
  561. m = uint8_t(data_a[i].scales[is + 5] & 63);
  562. } else {
  563. sc = uint8_t((data_a[i].scales[is + 5] & 0xF) | ((data_a[i].scales[is - 3] >> 6) << 4));
  564. m = uint8_t((data_a[i].scales[is + 5] >> 4) | ((data_a[i].scales[is + 1] >> 6) << 4));
  565. }
  566. const FLOAT_TYPE d2 = dall * sc;
  567. const FLOAT_TYPE m2 = dmin * m;
  568. const uint8_t hm1 = uint8_t(1 << (2 * il ));
  569. const uint8_t hm2 = uint8_t(1 << (2 * il + 1));
  570. data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx ] & 0xF) + (((data_a[i].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1);
  571. data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] & 0xF) + (((data_a[i].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1);
  572. data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx ] >> 4) + (((data_a[i].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2);
  573. data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[i].qs[qs_idx + 1] >> 4) + (((data_a[i].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2);
  574. }
  575. }
  576. """
# GLSL compute-shader body that dequantizes Q6_K blocks into a plain output
# buffer. One 64-thread workgroup walks up to 256 blocks (wgy loop); each
# thread reconstructs 4 values (y_idx +0/+32/+64/+96) by combining a low
# nibble from ql with 2 high bits from qh, recentering by -32, and scaling by
# the per-group int8 scale times the block-wide factor d. The A_TYPE/D_TYPE/
# FLOAT_TYPE/QUANT_K tokens are placeholders — presumably substituted by the
# generator before compilation (not visible in this chunk).
dequant_q6_K_body: str = """
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
[[unroll]] for (int wgy = 0; wgy < 256; wgy++) {
const int i = int(gl_WorkGroupID.x * 256 + wgy);
if (i >= p.M * p.K / QUANT_K) {
return;
}
const int tid = int(gl_LocalInvocationID.x);
const int ip = tid / 32;
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;
const int y_idx = i * QUANT_K + 128 * ip + il;
const int ql_idx = 64 * ip + il;
const uint8_t qh = data_a[i].qh[32 * ip + il];
const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d);
data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32)));
data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32)));
data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32)));
data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32)));
}
}
"""
  609. # Mul Mat Vec
# Shared "#version 450" preamble for all mul_mat_vec shaders below. Enables
# [[unroll]] (control-flow attributes) and requires 16-bit and 8-bit storage
# so the bodies can declare uint8_t/int8_t buffer members.
mul_mat_vec_head: str = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
"""
# Generic matrix-vector multiply template for the non-K quant formats. Each
# workgroup (QUANT_K threads) handles one matrix row: every thread
# dequantizes a pair of values via the injected DEQUANT_FUNC snippet (which
# must define a 2-component `v`), multiplies by the corresponding entries of
# data_b, accumulates into shared tmp[], then a barrier-synchronized tree
# reduction leaves the dot product in tmp[0] for thread 0 to write out.
# QUANT_K/QUANT_R/A_TYPE/B_TYPE/D_TYPE/FLOAT_TYPE/DEQUANT_FUNC are
# placeholders filled in elsewhere by the generator.
mul_mat_vec_body: str = """
layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[QUANT_K];
void main() {
const int block_size = int(gl_WorkGroupSize.x);
const int row = int(gl_WorkGroupID.x);
const int tid = int(gl_LocalInvocationID.x);
const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
tmp[tid] = FLOAT_TYPE(0.0f);
[[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid;
const int ib = (row*p.ncols + col)/QUANT_K; // block index
const int iqs = (col%QUANT_K)/QUANT_R; // quant index
const int iybs = col - col%QUANT_K; // y block start index
DEQUANT_FUNC
// matrix multiplication
tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + 0]);
tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(data_b[p.b_offset + iybs + iqs + y_offset]);
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = block_size/2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
  656. # K-quants
# Matrix-vector multiply for Q2_K blocks. A 32-thread workgroup processes one
# row; threads are split into K_QUANTS_PER_ITERATION interleaved groups (ix)
# that stride over the row's blocks. Per block, sum1 accumulates the scaled
# 2-bit quants (low nibble of each scale byte) and sum2 the per-group minima
# (high nibble), combined as dall*sum1 - dmin*sum2 where dall/dmin come from
# the two halves of the block's d (d.x/d.y). Shared tmp[32] + tree reduction
# produce the final dot product. Placeholder tokens (A_TYPE, QUANT_K,
# K_QUANTS_PER_ITERATION, ...) are substituted by the generator.
mul_mat_vec_q2_K_body: str = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int v_in = tid - step*v_im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const int q_offset = 32*v_im + l0;
const int s_offset = 8*v_im;
const int y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 2) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 4) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 4) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 6) & 3)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 6) & 3);
sum2 += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 0] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 1] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 2] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 3] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 4] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 5] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 6] >> 4) & 0xF)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE((data_a[ib0 + i].scales[s_offset + 7] >> 4) & 0xF);
}
tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
# Matrix-vector multiply for Q3_K blocks. Same 32-thread / one-row layout as
# the q2_K variant. Each 3-bit quant is rebuilt from 2 bits in qs plus a high
# bit from the hmask bitplane (subtracting 4 when the mask bit is clear);
# 6-bit scales are assembled from a low nibble in scales[0..7] and 2 extra
# bits packed in scales[8..11], recentred by -32, then everything is scaled
# by the block factor d. Shared tmp[32] + tree reduction finish the dot
# product. Placeholder tokens are substituted by the generator.
mul_mat_vec_q3_K_body: str = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int v_in = tid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im));
const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const int q_offset = 32*v_im + l0;
const int y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
const uint s_shift = 4 * v_im;
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum += FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 96]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[6] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 16]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[1] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 48]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[3] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l + 80]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[5] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 9] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4))
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l +112]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[7] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[11] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4));
}
tmp[16 * ix + tid] += d * sum;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
# Matrix-vector multiply for Q4_K blocks. 32 threads per row; each thread
# handles two 32-value strips 128 apart (y1_idx/y2_idx). 6-bit scales and
# minima sc0..sc7 are unpacked from the packed scales[] bytes; 4-bit quants
# come from the low/high nibbles of qs. The per-block result is
# dall*(sx*sc0 + sy*sc1 + sz*sc4 + sw*sc5) - dmin*smin, with dall/dmin from
# d.x/d.y. Two compile-time variants are selected by K_QUANTS_PER_ITERATION
# (4 vs 2 quants per strip). Shared tmp[32] + tree reduction finish the dot
# product. Placeholder tokens are substituted by the generator.
mul_mat_vec_q4_K_body: str = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int v_in = il % 2;
const int l0 = n * (2 * ir + v_in); // 0...15
const int q_offset = 32*v_im + l0;
const int y_offset = 64*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y1_idx = i * QUANT_K + y_offset;
const int y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
#if K_QUANTS_PER_ITERATION == 2
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(data_b[p.b_offset + y1_idx] * q4_0 + data_b[p.b_offset + y1_idx + 1] * q4_1 + data_b[p.b_offset + y1_idx + 2] * q4_2 + data_b[p.b_offset + y1_idx + 3] * q4_3);
const FLOAT_TYPE sy = FLOAT_TYPE(data_b[p.b_offset + y1_idx + 32] * q4_4 + data_b[p.b_offset + y1_idx + 33] * q4_5 + data_b[p.b_offset + y1_idx + 34] * q4_6 + data_b[p.b_offset + y1_idx + 35] * q4_7);
const FLOAT_TYPE sz = FLOAT_TYPE(data_b[p.b_offset + y2_idx] * q4_8 + data_b[p.b_offset + y2_idx + 1] * q4_9 + data_b[p.b_offset + y2_idx + 2] * q4_10 + data_b[p.b_offset + y2_idx + 3] * q4_11);
const FLOAT_TYPE sw = FLOAT_TYPE(data_b[p.b_offset + y2_idx + 32] * q4_12 + data_b[p.b_offset + y2_idx + 33] * q4_13 + data_b[p.b_offset + y2_idx + 34] * q4_14 + data_b[p.b_offset + y2_idx + 35] * q4_15);
const FLOAT_TYPE smin = FLOAT_TYPE(
data_b[p.b_offset + y1_idx ] * sc2 + data_b[p.b_offset + y1_idx + 32] * sc3 + data_b[p.b_offset + y2_idx ] * sc6 + data_b[p.b_offset + y2_idx + 32] * sc7
 + data_b[p.b_offset + y1_idx + 1] * sc2 + data_b[p.b_offset + y1_idx + 33] * sc3 + data_b[p.b_offset + y2_idx + 1] * sc6 + data_b[p.b_offset + y2_idx + 33] * sc7
 + data_b[p.b_offset + y1_idx + 2] * sc2 + data_b[p.b_offset + y1_idx + 34] * sc3 + data_b[p.b_offset + y2_idx + 2] * sc6 + data_b[p.b_offset + y2_idx + 34] * sc7
 + data_b[p.b_offset + y1_idx + 3] * sc2 + data_b[p.b_offset + y1_idx + 35] * sc3 + data_b[p.b_offset + y2_idx + 3] * sc6 + data_b[p.b_offset + y2_idx + 35] * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
#else
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(data_b[p.b_offset + y1_idx ] * q4_0 + data_b[p.b_offset + y1_idx + 1] * q4_1);
const FLOAT_TYPE sy = FLOAT_TYPE(data_b[p.b_offset + y1_idx + 32] * q4_2 + data_b[p.b_offset + y1_idx + 33] * q4_3);
const FLOAT_TYPE sz = FLOAT_TYPE(data_b[p.b_offset + y2_idx ] * q4_4 + data_b[p.b_offset + y2_idx + 1] * q4_5);
const FLOAT_TYPE sw = FLOAT_TYPE(data_b[p.b_offset + y2_idx + 32] * q4_6 + data_b[p.b_offset + y2_idx + 33] * q4_7);
const FLOAT_TYPE smin = FLOAT_TYPE(
data_b[p.b_offset + y1_idx] * sc2 + data_b[p.b_offset + y1_idx + 32] * sc3 + data_b[p.b_offset + y2_idx] * sc6 + data_b[p.b_offset + y2_idx + 32] * sc7
 + data_b[p.b_offset + y1_idx + 1] * sc2 + data_b[p.b_offset + y1_idx + 33] * sc3 + data_b[p.b_offset + y2_idx + 1] * sc6 + data_b[p.b_offset + y2_idx + 33] * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
#endif
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
# Matrix-vector multiply for Q5_K blocks. Like the q4_K variant but each
# 4-bit quant gains a 5th bit from the qh bitfield (+16 when the hm1/hm2 mask
# bit is set). Scales/minima sc0..sc7 are unpacked exactly as in q4_K, and
# the per-block contribution is dall*(sx*sc0 + sy*sc1 + sz*sc4 + sw*sc5)
# - dmin*smin (dall/dmin from d.x/d.y). This variant hard-codes a stride of 2
# (i += 2) instead of using K_QUANTS_PER_ITERATION. Shared tmp[32] + tree
# reduction finish the dot product; placeholders are substituted by the
# generator.
mul_mat_vec_q5_K_body: str = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/2; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%2; // 0 or 0, 1
const int il = tid/4; // 0...3
const int ir = tid - 4*il; // 0...7 or 0...3
const int v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int v_in = il % 2;
const int l0 = 4*ir + 2*v_in; // 0...15
const int q_offset = 32*v_im + l0;
const int y_offset = 64*v_im + l0;
const uint8_t hm1 = uint8_t(1 << (2*v_im));
const uint8_t hm2 = uint8_t(hm1 << 4);
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += 2) {
const int y1_idx = i * QUANT_K + y_offset;
const int y2_idx = y1_idx + 128;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
const FLOAT_TYPE sx = FLOAT_TYPE(
data_b[p.b_offset + y1_idx ] * (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 1] * (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 16] * (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 17] * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0))
);
const FLOAT_TYPE sy = FLOAT_TYPE(
data_b[p.b_offset + y1_idx + 32] * (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 33] * (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 48] * (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y1_idx + 49] * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0))
);
const FLOAT_TYPE sz = FLOAT_TYPE(
data_b[p.b_offset + y2_idx ] * (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 1] * (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 16] * (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 17] * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0))
);
const FLOAT_TYPE sw = FLOAT_TYPE(
data_b[p.b_offset + y2_idx + 32] * (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 33] * (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 48] * (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0))
 + data_b[p.b_offset + y2_idx + 49] * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0))
);
const FLOAT_TYPE smin = FLOAT_TYPE(
(data_b[p.b_offset + y1_idx] + data_b[p.b_offset + y1_idx + 1] + data_b[p.b_offset + y1_idx + 16] + data_b[p.b_offset + y1_idx + 17]) * sc2 + (data_b[p.b_offset + y1_idx + 32] + data_b[p.b_offset + y1_idx + 33] + data_b[p.b_offset + y1_idx + 48] + data_b[p.b_offset + y1_idx + 49]) * sc3
 + (data_b[p.b_offset + y2_idx] + data_b[p.b_offset + y2_idx + 1] + data_b[p.b_offset + y2_idx + 16] + data_b[p.b_offset + y2_idx + 17]) * sc6 + (data_b[p.b_offset + y2_idx + 32] + data_b[p.b_offset + y2_idx + 33] + data_b[p.b_offset + y2_idx + 48] + data_b[p.b_offset + y2_idx + 49]) * sc7
);
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
# Matrix-vector multiply for Q6_K blocks. 32 threads per row; each 6-bit
# quant is rebuilt from a ql nibble plus 2 bits of qh, recentred by -32, and
# weighted by the signed 8-bit scale and the block factor d. Two compile-time
# variants are selected by K_QUANTS_PER_ITERATION: the ==1 path does one
# fully unrolled 8-term sum per block, the other loops l=0..3 over 4 quants
# per strip. Shared tmp[32] + tree reduction finish the dot product;
# placeholders are substituted by the generator.
mul_mat_vec_q6_K_body: str = """
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
int ncols;
int b_offset;
int d_offset;
} p;
shared FLOAT_TYPE tmp[32];
void main() {
const int row = int(gl_WorkGroupID.x);
const int num_blocks_per_row = p.ncols / QUANT_K;
const int ib0 = row*num_blocks_per_row;
const int tid = int(gl_LocalInvocationID.x)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = int(gl_LocalInvocationID.x)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int v_in = tid - step*v_im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15
const int is = 0;
#else
const int l0 = 4 * v_in; // 0, 4, 8, ..., 28
const int is = v_in / 4;
#endif
const int ql_offset = 64*v_im + l0;
const int qh_offset = 32*v_im + l0;
const int s_offset = 8*v_im + is;
const int y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
#if K_QUANTS_PER_ITERATION == 1
FLOAT_TYPE sum = FLOAT_TYPE(data_b[p.b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
#else
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) {
sum += FLOAT_TYPE(data_b[p.b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
 + FLOAT_TYPE(data_b[p.b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
#endif
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = 16; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[p.d_offset + row] = D_TYPE(tmp[0]);
}
}
"""
# Complete standalone shader (includes its own #version header) for a
# mat-vec multiply where the x matrix is transposed and permuted (ix index
# order: row, channel, column) and y is permuted only (iy: channel-major).
# One 32-thread workgroup computes one dst element: threads stride over
# ncols_x accumulating xi*yi into shared tmp, then tree-reduce. channel_x
# maps a y channel back to its x channel via the nchannels_y/nchannels_x
# ratio (presumably for broadcast across channels — confirm against caller).
mul_mat_p021_src: str = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint nchannels_x;
uint nchannels_y;
uint b_offset;
uint d_offset;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / (p.nchannels_y / p.nchannels_x);
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
tmp[tid] = FLOAT_TYPE(0.0f);
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
// y is not transposed but permuted
const uint iy = channel*nrows_y + row_y;
tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
}
// dst is not transposed and not permuted
const uint idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[idst] = tmp[0];
}
}
"""
# Matrix-vector multiply shader for a non-contiguous f16 src0: explicit
# row/channel strides come in via push constants instead of being derived
# from the tensor shape. Same 32-thread shared-memory reduction as the
# p021 variant above.
# NOTE(review): p.b_offset and p.d_offset are declared but not referenced in
# this shader body — presumably kept for push-constant layout compatibility
# with the host side; confirm against the C++ dispatch code.
mul_mat_nc_src = """#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
uint channel_x_divisor;
uint b_offset;
uint d_offset;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / p.channel_x_divisor;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
const uint idst = channel*nrows_dst + row_dst;
tmp[tid] = 0.0f;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
tmp[tid] += xi * FLOAT_TYPE(data_b[iy]);
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[idst] = tmp[0];
}
}
"""
# F32 to F16
# Converts an M x K float matrix to float16, one element per invocation.
# Despite the readonly/writeonly naming convention, binding 1 ("data_b")
# is the float16_t destination here.
f32_to_f16_src = """#version 450
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 1) writeonly buffer D {float16_t data_b[];};
layout (push_constant) uniform parameter
{
int M;
int K;
int stride_a;
int stride_b;
} p;
void main() {
const int row = int(gl_GlobalInvocationID.x % p.K);
const int col = int(gl_GlobalInvocationID.x / p.K);
if (row < p.K && col < p.M) {
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
}
}
"""
# Common preamble for simple element-wise / row-wise ops below: version,
# 16-bit storage extension, and a generic 4-field push-constant block
# (two sizes KX/KY plus two float parameters whose meaning is per-op).
generic_head = """
#version 450
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint KX;
uint KY;
float param1;
float param2;
} p;
"""
# MUL F32
# Element-wise multiply: d[idx] = a[idx] * b[idx % KY], i.e. b is broadcast
# with period KY over a's KX elements.
mul_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.KX) {
return;
}
data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(data_b[idx % p.KY]));
}
"""
# ADD
# Element-wise add with the same modulo-KY broadcast of b as mul_body.
add_body = """
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.KX) {
return;
}
data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) + FLOAT_TYPE(data_b[idx % p.KY]));
}
"""
# SCALE
# Multiplies every element by the scalar passed in p.param1.
scale_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.KX) {
return;
}
data_d[idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
}
"""
# SQR
# Element-wise square: d[idx] = a[idx]^2.
sqr_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.KX) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]);
data_d[idx] = D_TYPE(val * val);
}
"""
# CLAMP
# Clamps each element into [p.param1, p.param2] (param1 = min, param2 = max).
clamp_body = """layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint idx = gl_GlobalInvocationID.x;
if (idx >= p.KX) {
return;
}
const FLOAT_TYPE val = FLOAT_TYPE(data_a[idx]);
data_d[idx] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
}
"""
# CPY
# Header for the copy shaders: decomposes the flat invocation index into 3D
# coordinates twice — once with the source extents (ne00/ne01) and once with
# the destination extents (ne10/ne11) — then applies the respective strides.
# The nb* strides are applied directly to typed array indices, so they are
# element-granular offsets, not byte offsets. The body (cpy_end or
# cpy_f16_f16_end) performs the actual element store.
cpy_src = """#version 450
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint nb00; uint nb01; uint nb02;
uint ne10; uint ne11; uint nb10; uint nb11; uint nb12;
uint d_offset;
} p;
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
if (gl_GlobalInvocationID.x >= p.ne) {
return;
}
const uint i02 = gl_GlobalInvocationID.x / (p.ne00*p.ne01);
const uint i01 = (gl_GlobalInvocationID.x - i02*p.ne01*p.ne00) / p.ne00;
const uint i00 = gl_GlobalInvocationID.x - i02*p.ne01*p.ne00 - i01*p.ne00;
const uint a_idx = i00*p.nb00 + i01*p.nb01 + i02*p.nb02;
const uint i12 = gl_GlobalInvocationID.x / (p.ne10*p.ne11);
const uint i11 = (gl_GlobalInvocationID.x - i12*p.ne11*p.ne10) / p.ne10;
const uint i10 = gl_GlobalInvocationID.x - i12*p.ne11*p.ne10 - i11*p.ne10;
const uint d_idx = i10*p.nb10 + i11*p.nb11 + i12*p.nb12;
"""
# Generic copy tail: converts through D_TYPE while storing.
cpy_end = """
data_d[p.d_offset + d_idx] = D_TYPE(data_a[a_idx]);
}
"""
# Causes an optimization error otherwise
# f16->f16 copy tail: stores without the D_TYPE() conversion wrapper used by
# cpy_end (the per-comment-above workaround for a shader-compiler issue).
cpy_f16_f16_end = """
data_d[p.d_offset + d_idx] = data_a[a_idx];
}
"""
# GET_ROWS
# Gathers rows of a (possibly quantized) matrix by the int indices in data_b.
# Each invocation handles two output values (col = gid.x * 2) because
# DEQUANT_FUNC produces a vec2 `v` per quant index.
# NOTE(review): the GLSL comment inside the string says the destination index
# is row*p.KX + col, but the code computes di = row*p.KY + col — presumably
# KX == KY for the shapes this is dispatched with; confirm against the caller.
get_rows_body = """
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_8bit_storage : require
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
void main() {
const uint col = int(gl_GlobalInvocationID.x) * 2;
const uint row = int(gl_GlobalInvocationID.y);
if (col >= p.KY) {
return;
}
const uint r = uint(data_b[row]);
// copy data_a[r*p.KY + col] to dst[row*p.KX + col]
const uint xi = r*p.KY + col;
const uint di = row*p.KY + col;
const uint ib = xi/QUANT_K; // block index
const uint iqs = (xi%QUANT_K)/QUANT_R; // quant index
const uint iybs = di - di%QUANT_K; // y block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
DEQUANT_FUNC
dst[iybs + iqs + 0] = D_TYPE(v.x);
dst[iybs + iqs + y_offset] = D_TYPE(v.y);
}
"""
# UNARY
# GELU activation using the tanh approximation, written with exp():
# 0.5*x*(2 - 2/(exp(2v)+1)) is algebraically x*0.5*(1 + tanh(v)) with
# v = sqrt(2/pi)*x*(1 + 0.044715*x^2).
gelu_body = """
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const uint i = gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float xi = float(data_a[i]);
const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi);
data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1)));
}
"""
# SiLU (swish) activation: x * sigmoid(x) = x / (1 + exp(-x)).
silu_body = """
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float xi = float(data_a[i]);
data_d[i] = D_TYPE(xi / (1.0f + exp(-xi)));
}
"""
# ReLU activation: max(x, 0).
relu_body = """
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
data_d[i] = max(float(data_a[i]), 0);
}
"""
# DIAG_MASK_INF
# Push constants for the causal-masking shader: row width, rows per channel,
# and the number of past positions that remain unmasked.
diag_mask_inf_head = """#version 450
#extension GL_EXT_shader_16bit_storage : require
layout (push_constant) uniform parameter
{
uint ncols;
uint rows_per_channel;
uint n_past;
} p;
"""
# Masks out "future" positions for causal attention: when
# col > n_past + (row within the channel), it subtracts
# float(0xFFFFFFFF) (~4.3e9) from the element, pushing it to a large
# negative value; otherwise uint(false)*0xFFFFFFFF == 0 and the element
# passes through unchanged (branchless mask).
diag_mask_inf_body = """
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint col = gl_GlobalInvocationID.y;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint i = row*p.ncols + col;
data_d[i] = D_TYPE(data_a[i] - float(uint(col > p.n_past + row % p.rows_per_channel) * 0xFFFFFFFF));
}
"""
  1389. # NORMS
  1390. norm_body = """
  1391. #extension GL_EXT_control_flow_attributes : enable
  1392. #define BLOCK_SIZE 512
  1393. layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
  1394. layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
  1395. layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
  1396. shared vec2 sum[BLOCK_SIZE];
  1397. void main() {
  1398. const uint row = gl_WorkGroupID.x;
  1399. const uint tid = gl_LocalInvocationID.x;
  1400. const float eps = 1e-5f;
  1401. sum[tid] = vec2(0.0f, 0.0f);
  1402. [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
  1403. const float xi = float(data_a[row*p.KX + col]);
  1404. sum[tid].x += xi;
  1405. sum[tid].y += xi * xi;
  1406. }
  1407. // sum up partial sums and write back result
  1408. barrier();
  1409. [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
  1410. if (tid < s) {
  1411. sum[tid] += sum[tid + s];
  1412. }
  1413. barrier();
  1414. }
  1415. const float mean = sum[0].x / p.KX;
  1416. const float var = sum[0].y / p.KX - mean * mean;
  1417. const float inv_std = inversesqrt(var + 1e-5f);
  1418. [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
  1419. data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std);
  1420. }
  1421. }
  1422. """
# RMS norm over each row of KX elements: one workgroup per row reduces
# sum(x^2) in shared memory, then scales the row by
# 1/sqrt(mean(x^2) + param1) — param1 carries the epsilon.
rms_norm_body = """
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(p.KX);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}
"""
# SOFT_MAX
# Row-wise softmax with optional additive mask (data_b, selected by p.KY > 0)
# and pre-scale p.param1. Three shared-memory passes per row: max-reduce for
# numerical stability, exp-and-sum, then divide in place (binding 2 is
# read-write for that final pass).
soft_max_body = """
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
layout (binding = 2) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint rowx = gl_WorkGroupID.x;
const uint rowy = rowx % p.KY;
// Find max
vals[tid] = uintBitsToFloat(0xFF800000);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
}
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] = max(vals[tid], vals[tid + s]);
}
barrier();
}
const FLOAT_TYPE max_val = vals[0];
barrier();
// Sum up values
vals[tid] = FLOAT_TYPE(0.0f);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const uint i = rowx * p.KX + col;
const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.param1 + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
vals[tid] += val;
data_d[i] = D_TYPE(val);
}
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
vals[tid] += vals[tid + s];
}
barrier();
}
const D_TYPE divisor = D_TYPE(vals[0]);
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[rowx*p.KX + col] /= divisor;
}
}
"""
# ROPE
# Rotary position embedding (interleaved/"normal" style): each invocation
# rotates one adjacent pair (x[i], x[i+1]) by a position-dependent angle,
# with YaRN extrapolation/interpolation blending via rope_yarn().
# data_b holds per-row int positions.
rope_src = """
#version 450
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
float freq_scale;
uint p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[4];
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
float mscale = p.attn_factor;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = p.freq_scale * theta_extrap;
float theta = theta_interp;
if (p.ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint i = row*p.ncols + col;
const uint i2 = row/p.p_delta_rows;
const int pos = data_b[i2];
const float theta_base = pos * pow(p.freq_base, -float(col)/p.ncols);
float cos_theta, sin_theta;
rope_yarn(theta_base, col, cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + 1]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + 1] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
"""
# NeoX-style rotary embedding: pairs are (x[i], x[i + ndims/2]) rather than
# adjacent elements, columns beyond the first ndims block are passed through
# unrotated, and the angle uses theta_scale^(col/2). Shares the rope_yarn()
# helper logic with rope_src above.
rope_neox_src = """
#version 450
#extension GL_EXT_shader_16bit_storage : require
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {int data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
layout (push_constant) uniform parameter {
uint ncols;
uint ndims;
float freq_scale;
uint p_delta_rows;
float freq_base;
float ext_factor;
float attn_factor;
float corr_dims[4];
float theta_scale;
float inv_ndims;
} p;
float rope_yarn_ramp(const float low, const float high, const uint i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) {
float mscale = p.attn_factor;
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = p.freq_scale * theta_extrap;
float theta = theta_interp;
if (p.ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale);
}
cos_theta = cos(theta) * mscale;
sin_theta = sin(theta) * mscale;
}
void main() {
const uint col = gl_GlobalInvocationID.y * 2;
const uint row = gl_GlobalInvocationID.x;
if (col >= p.ncols) {
return;
}
const uint ib = col / p.ndims;
const uint ic = col % p.ndims;
if (ib > 0) {
const uint i = row*p.ncols + ib*p.ndims + ic;
data_d[i + 0] = data_a[i + 0];
data_d[i + 1] = data_a[i + 1];
return;
}
const uint i = row*p.ncols + ib*p.ndims + ic/2;
const uint i2 = row/p.p_delta_rows;
const float cur_rot = p.inv_ndims * ic - ib;
const int pos = data_b[i2];
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f);
float cos_theta, sin_theta;
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
const float x0 = float(data_a[i + 0]);
const float x1 = float(data_a[i + p.ndims/2]);
data_d[i + 0] = D_TYPE(x0*cos_theta - x1*sin_theta);
data_d[i + p.ndims/2] = D_TYPE(x0*sin_theta + x1*cos_theta);
}
"""
# Name/path of the glslc compiler executable invoked by string_to_spv().
GLSLC = "glslc"
# Loop bound for the per-type shader generation loops below.
VK_NUM_TYPES = 16
# ggml tensor type ids, mirroring the GGML_TYPE_* enum on the C side.
# NOTE(review): values 4 and 5 are intentionally absent here — presumably
# removed quantization formats; verify against ggml.h.
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q4_0 = 2
GGML_TYPE_Q4_1 = 3
GGML_TYPE_Q5_0 = 6
GGML_TYPE_Q5_1 = 7
GGML_TYPE_Q8_0 = 8
GGML_TYPE_Q8_1 = 9
GGML_TYPE_Q2_K = 10
GGML_TYPE_Q3_K = 11
GGML_TYPE_Q4_K = 12
GGML_TYPE_Q5_K = 13
GGML_TYPE_Q6_K = 14
GGML_TYPE_Q8_K = 15
# Maps type ids to the lowercase names used in generated shader filenames.
type_names = {
    GGML_TYPE_F32: "f32",
    GGML_TYPE_F16: "f16",
    GGML_TYPE_Q4_0: "q4_0",
    GGML_TYPE_Q4_1: "q4_1",
    GGML_TYPE_Q5_0: "q5_0",
    GGML_TYPE_Q5_1: "q5_1",
    GGML_TYPE_Q8_0: "q8_0",
    GGML_TYPE_Q8_1: "q8_1",
    GGML_TYPE_Q2_K: "q2_K",
    GGML_TYPE_Q3_K: "q3_K",
    GGML_TYPE_Q4_K: "q4_K",
    GGML_TYPE_Q5_K: "q5_K",
    GGML_TYPE_Q6_K: "q6_K",
    GGML_TYPE_Q8_K: "q8_K",
}
# Passed to the k-quant mul_mat_vec shaders as a #define (see the defines
# dict in main()); controls how much work each iteration processes.
K_QUANTS_PER_ITERATION = 2
# NOTE(review): presumably caps concurrent glslc compile jobs — its use is
# further down in the file; confirm.
ASYNCIO_CONCURRENCY = 64
# Generated .comp/.spv files are written to the system temp directory.
output_dir = gettempdir()
# Protects shader_fnames, which collects (shader name, output path) tuples
# appended by the concurrent string_to_spv() tasks.
lock = asyncio.Lock()
shader_fnames = []
  1655. async def string_to_spv(name, code, defines, fp16=True):
  1656. f = NamedTemporaryFile(mode="w", delete=False)
  1657. f.write(code)
  1658. f.flush()
  1659. name = f"{name}{'_fp32' if not fp16 else ''}"
  1660. fname = os.path.join(output_dir, f"{name}.comp")
  1661. cmd = [GLSLC, "-fshader-stage=compute", "--target-env=vulkan1.2", "-O", f.name, "-o", fname]
  1662. cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
  1663. proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
  1664. stdout, stderr = await proc.communicate()
  1665. stdout = stdout.decode()
  1666. error = stderr.decode()
  1667. if proc.returncode:
  1668. # Generate preprocessed code
  1669. cmd = [GLSLC, "-E", f.name]
  1670. cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
  1671. proc = await asyncio.create_subprocess_exec(*cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE)
  1672. stdout, stderr = await proc.communicate()
  1673. print(" ".join(cmd))
  1674. if proc.returncode:
  1675. raise RuntimeError(f"{name=} {f.name=} {stdout=} {stderr=}")
  1676. preprocessed_code = stdout.decode()
  1677. cmd.extend([f"-D{key}={value}" for key, value in defines.items()])
  1678. code_with_lines = "\n".join([f"{i + 1}: {line}" for i, line in enumerate(preprocessed_code.splitlines())])
  1679. print(f"ERROR compiling {name}\n\n{code_with_lines}\n\n{error}")
  1680. f.close()
  1681. os.remove(f.name)
  1682. sys.exit(proc.returncode)
  1683. f.close()
  1684. os.remove(f.name)
  1685. async with lock:
  1686. shader_fnames.append((name, fname))
  1687. async def main():
  1688. print("ggml_vulkan: Generating and compiling shaders to SPIR-V")
  1689. tasks = []
  1690. for fp16 in (False, True):
  1691. # mulmat
  1692. if fp16:
  1693. shader_float_type = shader_f16
  1694. load_vec = "8"
  1695. vec_type_f16 = "f16mat2x4"
  1696. vec_type = "mat2x4"
  1697. else:
  1698. shader_float_type = shader_f32
  1699. load_vec = "4"
  1700. vec_type_f16 = "f16vec4"
  1701. vec_type = "vec4"
  1702. stream = []
  1703. stream.extend((mulmat_head, shader_float_type, mulmat_body))
  1704. tasks.append(string_to_spv("matmul_f32_l", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1705. tasks.append(string_to_spv("matmul_f32_m", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1706. tasks.append(string_to_spv("matmul_f32_s", "".join(stream), {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1707. tasks.append(string_to_spv("matmul_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1708. tasks.append(string_to_spv("matmul_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1709. tasks.append(string_to_spv("matmul_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1710. tasks.append(string_to_spv("matmul_f16_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
  1711. tasks.append(string_to_spv("matmul_f16_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
  1712. tasks.append(string_to_spv("matmul_f16_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float16_t", "D_TYPE": "float"}, fp16))
  1713. tasks.append(string_to_spv("matmul_f16_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
  1714. tasks.append(string_to_spv("matmul_f16_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
  1715. tasks.append(string_to_spv("matmul_f16_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type_f16, "D_TYPE": "float"}, fp16))
  1716. tasks.append(string_to_spv("matmul_f16_f32_l", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1717. tasks.append(string_to_spv("matmul_f16_f32_m", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1718. tasks.append(string_to_spv("matmul_f16_f32_s", "".join(stream), {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}, fp16))
  1719. tasks.append(string_to_spv("matmul_f16_f32_aligned_l", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1720. tasks.append(string_to_spv("matmul_f16_f32_aligned_m", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1721. tasks.append(string_to_spv("matmul_f16_f32_aligned_s", "".join(stream), {"LOAD_VEC": load_vec, "A_TYPE": vec_type_f16, "B_TYPE": vec_type, "D_TYPE": "float"}, fp16))
  1722. # Shaders where precision is needed, so no fp16 version
  1723. # mul mat vec
  1724. for i in range(0, VK_NUM_TYPES):
  1725. stream.clear()
  1726. stream.extend((mul_mat_vec_head, shader_int8_ext, shader_f32))
  1727. if i == GGML_TYPE_F16:
  1728. stream.extend((shader_f16_defines, shader_f16_dequant_func, mul_mat_vec_body))
  1729. elif i == GGML_TYPE_Q4_0:
  1730. stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, mul_mat_vec_body))
  1731. elif i == GGML_TYPE_Q4_1:
  1732. stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, mul_mat_vec_body))
  1733. elif i == GGML_TYPE_Q5_0:
  1734. stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, mul_mat_vec_body))
  1735. elif i == GGML_TYPE_Q5_1:
  1736. stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, mul_mat_vec_body))
  1737. elif i == GGML_TYPE_Q8_0:
  1738. stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, mul_mat_vec_body))
  1739. elif i == GGML_TYPE_Q2_K:
  1740. stream.extend((shader_q2_K_defines, mul_mat_vec_q2_K_body))
  1741. elif i == GGML_TYPE_Q3_K:
  1742. stream.extend((shader_q3_K_defines, mul_mat_vec_q3_K_body))
  1743. elif i == GGML_TYPE_Q4_K:
  1744. stream.extend((shader_q4_K_defines, mul_mat_vec_q4_K_body))
  1745. elif i == GGML_TYPE_Q5_K:
  1746. stream.extend((shader_q5_K_defines, mul_mat_vec_q5_K_body))
  1747. elif i == GGML_TYPE_Q6_K:
  1748. stream.extend((shader_q6_K_defines, mul_mat_vec_q6_K_body))
  1749. else:
  1750. continue
  1751. tasks.append(string_to_spv(f"mul_mat_vec_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float", "K_QUANTS_PER_ITERATION": K_QUANTS_PER_ITERATION}))
  1752. # Dequant shaders
  1753. for i in range(0, VK_NUM_TYPES):
  1754. stream.clear()
  1755. stream.extend((dequant_head, shader_int8_ext, shader_f32))
  1756. if i == GGML_TYPE_F16:
  1757. stream.extend((shader_f16_defines, shader_f16_dequant_func, dequant_body))
  1758. elif i == GGML_TYPE_Q4_0:
  1759. stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, dequant_body))
  1760. elif i == GGML_TYPE_Q4_1:
  1761. stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, dequant_body))
  1762. elif i == GGML_TYPE_Q5_0:
  1763. stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, dequant_body))
  1764. elif i == GGML_TYPE_Q5_1:
  1765. stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, dequant_body))
  1766. elif i == GGML_TYPE_Q8_0:
  1767. stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, dequant_body))
  1768. elif i == GGML_TYPE_Q2_K:
  1769. stream.extend((shader_q2_K_defines, dequant_q2_K_body))
  1770. elif i == GGML_TYPE_Q3_K:
  1771. stream.extend((shader_q3_K_defines, dequant_q3_K_body))
  1772. elif i == GGML_TYPE_Q4_K:
  1773. stream.extend((shader_q4_K_defines, dequant_q4_K_body))
  1774. elif i == GGML_TYPE_Q5_K:
  1775. stream.extend((shader_q5_K_defines, dequant_q5_K_body))
  1776. elif i == GGML_TYPE_Q6_K:
  1777. stream.extend((shader_q6_K_defines, dequant_q6_K_body))
  1778. else:
  1779. continue
  1780. tasks.append(string_to_spv(f"dequant_{type_names[i]}", "".join(stream), {"D_TYPE": "float16_t"}))
  1781. tasks.append(string_to_spv("f32_to_f16", f32_to_f16_src, {}))
  1782. # get_rows
  1783. for i in range(0, VK_NUM_TYPES):
  1784. stream.clear()
  1785. stream.extend((generic_head, shader_int8_ext, shader_f32))
  1786. if i == GGML_TYPE_F16:
  1787. stream.extend((shader_f16_defines, shader_f16_dequant_func, get_rows_body))
  1788. elif i == GGML_TYPE_Q4_0:
  1789. stream.extend((shader_q4_0_defines, shader_q4_0_dequant_func, get_rows_body))
  1790. elif i == GGML_TYPE_Q4_1:
  1791. stream.extend((shader_q4_1_defines, shader_q4_1_dequant_func, get_rows_body))
  1792. elif i == GGML_TYPE_Q5_0:
  1793. stream.extend((shader_q5_0_defines, shader_q5_0_dequant_func, get_rows_body))
  1794. elif i == GGML_TYPE_Q5_1:
  1795. stream.extend((shader_q5_1_defines, shader_q5_1_dequant_func, get_rows_body))
  1796. elif i == GGML_TYPE_Q8_0:
  1797. stream.extend((shader_q8_0_defines, shader_q8_0_dequant_func, get_rows_body))
  1798. else:
  1799. continue
  1800. tasks.append(string_to_spv(f"get_rows_{type_names[i]}", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float16_t"}))
  1801. tasks.append(string_to_spv(f"get_rows_{type_names[i]}_f32", "".join(stream), {"B_TYPE": "float", "D_TYPE": "float"}))
  1802. tasks.append(string_to_spv("mul_mat_vec_p021_f16_f32", mul_mat_p021_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
  1803. tasks.append(string_to_spv("mul_mat_vec_nc_f16_f32", mul_mat_nc_src, {"A_TYPE": "float16_t", "B_TYPE": "float", "D_TYPE": "float"}))
  1804. # Norms
  1805. tasks.append(string_to_spv("norm_f32", f"{generic_head}\n{shader_f32}\n{norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1806. tasks.append(string_to_spv("rms_norm_f32", f"{generic_head}\n{shader_f32}\n{rms_norm_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1807. tasks.append(string_to_spv("cpy_f32_f32", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1808. tasks.append(string_to_spv("cpy_f32_f16", f"{cpy_src}\n{cpy_end}", {"A_TYPE": "float", "D_TYPE": "float16_t"}))
  1809. tasks.append(string_to_spv("cpy_f16_f16", f"{cpy_src}\n{cpy_f16_f16_end}", {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
  1810. tasks.append(string_to_spv("add_f32", f"{generic_head}\n{shader_f32}\n{add_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
  1811. tasks.append(string_to_spv("split_k_reduce", mulmat_split_k_reduce_src, {}))
  1812. tasks.append(string_to_spv("mul_f32", f"{generic_head}\n{shader_f32}\n{mul_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
  1813. tasks.append(string_to_spv("scale_f32", f"{generic_head}\n{shader_f32}\n{scale_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1814. tasks.append(string_to_spv("sqr_f32", f"{generic_head}\n{shader_f32}\n{sqr_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1815. tasks.append(string_to_spv("clamp_f32", f"{generic_head}\n{shader_f32}\n{clamp_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1816. tasks.append(string_to_spv("gelu_f32", f"{generic_head}\n{shader_f32}\n{gelu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1817. tasks.append(string_to_spv("silu_f32", f"{generic_head}\n{shader_f32}\n{silu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1818. tasks.append(string_to_spv("relu_f32", f"{generic_head}\n{shader_f32}\n{relu_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1819. tasks.append(string_to_spv("diag_mask_inf_f32", f"{diag_mask_inf_head}\n{shader_f32}\n{diag_mask_inf_body}", {"A_TYPE": "float", "D_TYPE": "float"}))
  1820. tasks.append(string_to_spv("soft_max_f32", f"{generic_head}\n{shader_f32}\n{soft_max_body}", {"A_TYPE": "float", "B_TYPE": "float", "D_TYPE": "float"}))
  1821. tasks.append(string_to_spv("rope_f32", rope_src, {"A_TYPE": "float", "D_TYPE": "float"}))
  1822. tasks.append(string_to_spv("rope_f16", rope_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
  1823. tasks.append(string_to_spv("rope_neox_f32", rope_neox_src, {"A_TYPE": "float", "D_TYPE": "float"}))
  1824. tasks.append(string_to_spv("rope_neox_f16", rope_neox_src, {"A_TYPE": "float16_t", "D_TYPE": "float16_t"}))
  1825. # Helper to decorate tasks with semaphore acquisition.
  1826. async def withSemaphore(sem, task):
  1827. async with sem:
  1828. return await task
  1829. # Run tasks concurrently guarded by a concurrency limit.
  1830. sem = asyncio.Semaphore(ASYNCIO_CONCURRENCY)
  1831. await asyncio.gather(*(withSemaphore(sem, task) for task in tasks))
  1832. with open("ggml-vulkan-shaders.hpp", "w") as f:
  1833. f.write("#include <cstdint>\n\n")
  1834. for name, path in sorted(shader_fnames):
  1835. with open(path, "rb") as spv:
  1836. counter = 0
  1837. newline_counter = 0
  1838. f.write(f"unsigned char {name}_data[] = {{\n")
  1839. for val in spv.read():
  1840. f.write(f"0x{val:02x},")
  1841. newline_counter += 1
  1842. counter += 1
  1843. if newline_counter >= 12:
  1844. newline_counter = 0
  1845. f.write("\n")
  1846. f.write("\n};\n")
  1847. f.write(f"const uint64_t {name}_len = {counter};\n\n")
  1848. os.remove(path)
  1849. if __name__ == "__main__":
  1850. parser = argparse.ArgumentParser(description="GGML Vulkan Shader Generator")
  1851. parser.add_argument("--glslc", help="Path to glslc")
  1852. args = parser.parse_args()
  1853. if args.glslc:
  1854. GLSLC = args.glslc
  1855. asyncio.run(main())