// mul_mm.comp
  1. #version 450
  2. #extension GL_EXT_control_flow_attributes : enable
  3. #extension GL_EXT_shader_16bit_storage : require
  4. #ifdef FLOAT16
  5. #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
  6. #endif
  7. #ifdef MUL_MAT_ID
  8. #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
  9. #endif
  10. #include "types.comp"
  11. #ifndef LOAD_VEC_A
  12. #define LOAD_VEC_A 1
  13. #endif
  14. #ifndef LOAD_VEC_B
  15. #define LOAD_VEC_B 1
  16. #endif
  17. layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
  18. layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
  19. layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
  20. layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
  21. #ifdef MUL_MAT_ID
  22. layout (binding = 3) readonly buffer IDS {int data_ids[];};
  23. #endif
  24. layout (push_constant) uniform parameter
  25. {
  26. uint M;
  27. uint N;
  28. uint K;
  29. uint stride_a;
  30. uint stride_b;
  31. uint stride_d;
  32. uint batch_stride_a;
  33. uint batch_stride_b;
  34. uint batch_stride_d;
  35. #ifdef MUL_MAT_ID
  36. uint nei0;
  37. uint nei1;
  38. uint nbi1;
  39. uint ne11;
  40. #else
  41. uint k_split;
  42. uint ne02;
  43. uint ne12;
  44. uint broadcast2;
  45. uint broadcast3;
  46. #endif
  47. } p;
  48. layout (constant_id = 1) const uint BM = 64;
  49. layout (constant_id = 2) const uint BN = 64;
  50. layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
  51. layout (constant_id = 4) const uint WM = 32;
  52. layout (constant_id = 5) const uint WN = 32;
  53. layout (constant_id = 6) const uint WMITER = 2;
  54. layout (constant_id = 7) const uint TM = 4;
  55. layout (constant_id = 8) const uint TN = 2;
  56. layout (constant_id = 9) const uint WARP = 32;
  57. shared FLOAT_TYPE buf_a[BM * (BK+1)];
  58. shared FLOAT_TYPE buf_b[BN * (BK+1)];
  59. #ifdef MUL_MAT_ID
  60. shared u16vec2 row_ids[2048];
  61. #endif
  62. void main() {
  63. #ifdef MUL_MAT_ID
  64. const uint expert_idx = gl_GlobalInvocationID.z;
  65. #else
  66. const uint batch_idx = gl_GlobalInvocationID.z;
  67. const uint i13 = batch_idx / p.ne12;
  68. const uint i12 = batch_idx % p.ne12;
  69. const uint i03 = i13 / p.broadcast3;
  70. const uint i02 = i12 / p.broadcast2;
  71. const uint batch_idx_a = i03 * p.ne02 + i02;
  72. #endif
  73. const uint blocks_m = (p.M + BM - 1) / BM;
  74. const uint ir = gl_WorkGroupID.x % blocks_m;
  75. const uint ik = gl_WorkGroupID.x / blocks_m;
  76. const uint ic = gl_WorkGroupID.y;
  77. const uint warp_i = gl_LocalInvocationID.x / WARP;
  78. const uint warp_r = warp_i % (BM / WM);
  79. const uint warp_c = warp_i / (BM / WM);
  80. const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
  81. const uint WSUBM = WM / WMITER;
  82. const uint WSUBN = WN / WNITER;
  83. const uint tiw = gl_LocalInvocationID.x % WARP;
  84. const uint tiwr = tiw % (WSUBM / TM);
  85. const uint tiwc = tiw / (WSUBM / TM);
  86. const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
  87. const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
  88. const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
  89. const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
  90. const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
  91. const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;
  92. #ifdef MUL_MAT_ID
  93. uint _ne1 = 0;
  94. for (uint ii1 = 0; ii1 < p.nei1; ii1++) {
  95. for (uint ii0 = 0; ii0 < p.nei0; ii0++) {
  96. if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
  97. row_ids[_ne1] = u16vec2(ii0, ii1);
  98. _ne1++;
  99. }
  100. }
  101. }
  102. barrier();
  103. // Workgroup has no work
  104. if (ic * BN >= _ne1) return;
  105. #endif
  106. #ifdef MUL_MAT_ID
  107. const uint start_k = 0;
  108. const uint end_k = p.K;
  109. #else
  110. const uint start_k = ik * p.k_split;
  111. const uint end_k = min(p.K, (ik + 1) * p.k_split);
  112. #endif
  113. uint pos_a = (
  114. #ifdef MUL_MAT_ID
  115. expert_idx * p.batch_stride_a +
  116. #else
  117. batch_idx_a * p.batch_stride_a +
  118. #endif
  119. ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
  120. #ifdef MUL_MAT_ID
  121. uint pos_b = 0;
  122. #else
  123. uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
  124. #endif
  125. float sums[WMITER * TM * WNITER * TN];
  126. FLOAT_TYPE cache_a[WMITER * TM];
  127. FLOAT_TYPE cache_b[WNITER * TN];
  128. [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
  129. sums[i] = 0.0f;
  130. }
  131. [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
  132. [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
  133. #if defined(DATA_A_F32) || defined(DATA_A_F16)
  134. #if LOAD_VEC_A == 8
  135. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  136. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  137. buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x);
  138. buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y);
  139. buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z);
  140. buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w);
  141. buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x);
  142. buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y);
  143. buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z);
  144. buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w);
  145. #elif LOAD_VEC_A == 4
  146. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  147. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  148. buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x);
  149. buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y);
  150. buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z);
  151. buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w);
  152. #else
  153. if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
  154. buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
  155. } else {
  156. buf_a[(loadc_a + l) * (BK+1) + loadr_a] = FLOAT_TYPE(0.0f);
  157. }
  158. #endif
  159. #elif defined(DATA_A_Q4_0)
  160. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  161. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
  162. const uint ib = idx / 16;
  163. const uint iqs = idx & 0xF;
  164. const float d = float(data_a[ib].d);
  165. const uint vui = uint(data_a[ib].qs[iqs]);
  166. const vec2 v = (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
  167. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  168. buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
  169. #elif defined(DATA_A_Q4_1)
  170. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  171. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
  172. const uint ib = idx / 16;
  173. const uint iqs = idx & 0xF;
  174. const float d = float(data_a[ib].d);
  175. const float m = float(data_a[ib].m);
  176. const uint vui = uint(data_a[ib].qs[iqs]);
  177. const vec2 v = vec2(vui & 0xF, vui >> 4) * d + m;
  178. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  179. buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
  180. #elif defined(DATA_A_Q5_0)
  181. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  182. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
  183. const uint ib = idx / 16;
  184. const uint iqs = idx & 0xF;
  185. const float d = float(data_a[ib].d);
  186. const uint uint_qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0];
  187. const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
  188. const uint vui = uint(data_a[ib].qs[iqs]);
  189. const vec2 v = (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
  190. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  191. buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
  192. #elif defined(DATA_A_Q5_1)
  193. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  194. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a;
  195. const uint ib = idx / 16;
  196. const uint iqs = idx & 0xF;
  197. const float d = float(data_a[ib].d);
  198. const float m = float(data_a[ib].m);
  199. const uint uint_qh = data_a[ib].qh;
  200. const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
  201. const uint vui = uint(data_a[ib].qs[iqs]);
  202. const vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
  203. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  204. buf_a[buf_idx + 16] = FLOAT_TYPE(v.y);
  205. #elif defined(DATA_A_Q8_0)
  206. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  207. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  208. const uint ib = idx / 16;
  209. const uint iqs = (idx & 0xF) * 2;
  210. const float d = float(data_a[ib].d);
  211. const vec2 v = vec2(int(data_a[ib].qs[iqs]), int(data_a[ib].qs[iqs + 1])) * d;
  212. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  213. buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
  214. #elif defined(DATA_A_Q2_K)
  215. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  216. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  217. const uint ib = idx / 128; // 2 values per idx
  218. const uint iqs = idx % 128; // 0..127
  219. const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30
  220. const uint scalesi = iqs / 8; // 0..15
  221. const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
  222. const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]);
  223. const uint scales = data_a[ib].scales[scalesi];
  224. const vec2 d = vec2(data_a[ib].d);
  225. const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4);
  226. buf_a[buf_idx ] = FLOAT_TYPE(v.x);
  227. buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
  228. #elif defined(DATA_A_Q3_K)
  229. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  230. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  231. const uint ib = idx / 128; // 2 values per idx
  232. const uint iqs = idx % 128; // 0..127
  233. const uint n = iqs / 64; // 0,1
  234. const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
  235. const uint hmi = (iqs % 16) * 2; // 0,2,4..30
  236. const uint j = (iqs % 64) / 4; // 0..3
  237. const uint is = iqs / 8; // 0..15
  238. const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3
  239. const uint qsshift = halfsplit * 2; // 0,2,4,6
  240. const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
  241. const int8_t us = int8_t(is < 4 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+8] >> 0) & 3) << 4) :
  242. is < 8 ? (data_a[ib].scales[is-0] & 0xF) | (((data_a[ib].scales[is+4] >> 2) & 3) << 4) :
  243. is < 12 ? (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is+0] >> 4) & 3) << 4) :
  244. (data_a[ib].scales[is-8] >> 4) | (((data_a[ib].scales[is-4] >> 6) & 3) << 4));
  245. const float dl = float(data_a[ib].d) * float(us - 32);
  246. buf_a[buf_idx ] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 0 : 4)));
  247. buf_a[buf_idx + 1] = FLOAT_TYPE(dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4)));
  248. #elif defined(DATA_A_Q4_K)
  249. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  250. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  251. const uint ib = idx / 128; // 2 values per idx
  252. const uint iqs = idx % 128; // 0..127
  253. const uint n = iqs / 32; // 0,1,2,3
  254. const uint b = (iqs % 32) / 16; // 0,1
  255. const uint is = 2 * n + b; // 0..7
  256. const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
  257. const vec2 loadd = vec2(data_a[ib].d);
  258. uint8_t sc;
  259. uint8_t mbyte;
  260. if (is < 4) {
  261. sc = uint8_t(data_a[ib].scales[is ] & 63);
  262. mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
  263. } else {
  264. sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
  265. mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
  266. }
  267. const float d = loadd.x * sc;
  268. const float m = loadd.y * mbyte;
  269. buf_a[buf_idx ] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) - m);
  270. buf_a[buf_idx + 1] = FLOAT_TYPE(d * float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) - m);
  271. #elif defined(DATA_A_Q5_K)
  272. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  273. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  274. const uint ib = idx / 128; // 2 values per idx
  275. const uint iqs = idx % 128; // 0..127
  276. const uint n = iqs / 32; // 0,1,2,3
  277. const uint b = (iqs % 32) / 16; // 0,1
  278. const uint is = 2 * n + b; // 0..7
  279. const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126
  280. const uint qhi = (iqs % 16) * 2; // 0,2,4..30
  281. const uint8_t hm = uint8_t(1 << (iqs / 16));
  282. const vec2 loadd = vec2(data_a[ib].d);
  283. uint8_t sc;
  284. uint8_t mbyte;
  285. if (is < 4) {
  286. sc = uint8_t(data_a[ib].scales[is ] & 63);
  287. mbyte = uint8_t(data_a[ib].scales[is + 4] & 63);
  288. } else {
  289. sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4));
  290. mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4));
  291. }
  292. const float d = loadd.x * sc;
  293. const float m = loadd.y * mbyte;
  294. buf_a[buf_idx ] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0)) - m);
  295. buf_a[buf_idx + 1] = FLOAT_TYPE(d * (float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 16 : 0)) - m);
  296. #elif defined(DATA_A_Q6_K)
  297. const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
  298. const uint buf_idx = (loadc_a + l) * (BK+1) + loadr_a * LOAD_VEC_A;
  299. const uint ib = idx / 128; // 2 values per idx
  300. const uint iqs = idx % 128; // 0..127
  301. const uint n = iqs / 64; // 0,1
  302. const uint b = (iqs % 64) / 32; // 0,1
  303. const uint is_b = (iqs % 16) / 8; // 0,1
  304. const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
  305. const uint is = 8 * n + qhshift + is_b; // 0..15
  306. const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
  307. const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
  308. const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]);
  309. buf_a[buf_idx ] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32));
  310. buf_a[buf_idx + 1] = FLOAT_TYPE(dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32));
  311. #endif
  312. }
  313. [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
  314. #if LOAD_VEC_B == 8
  315. #ifdef MUL_MAT_ID
  316. const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
  317. const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
  318. #else
  319. const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
  320. #endif
  321. const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
  322. buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x);
  323. buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y);
  324. buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z);
  325. buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w);
  326. buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x);
  327. buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y);
  328. buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z);
  329. buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
  330. #elif LOAD_VEC_B == 4
  331. #ifdef MUL_MAT_ID
  332. const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
  333. const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
  334. #else
  335. const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
  336. #endif
  337. const uint buf_idx = (loadc_b + l) * (BK+1) + loadr_b * LOAD_VEC_B;
  338. buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx].x);
  339. buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx].y);
  340. buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx].z);
  341. buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx].w);
  342. #elif !MUL_MAT_ID
  343. if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
  344. buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);
  345. } else {
  346. buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
  347. }
  348. #else
  349. const uint row_i = ic * BN + loadc_b + l;
  350. if (row_i < _ne1) {
  351. const u16vec2 row_idx = row_ids[row_i];
  352. buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
  353. } else {
  354. buf_b[(loadc_b + l) * (BK+1) + loadr_b] = FLOAT_TYPE(0.0f);
  355. }
  356. #endif
  357. }
  358. barrier();
  359. pos_a += BK / LOAD_VEC_A;
  360. pos_b += BK / LOAD_VEC_B;
  361. for (uint i = 0; i < BK; i++) {
  362. // Load from shared into cache
  363. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  364. [[unroll]] for (uint j = 0; j < TM; j++) {
  365. cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
  366. }
  367. }
  368. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  369. [[unroll]] for (uint j = 0; j < TN; j++) {
  370. cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
  371. }
  372. }
  373. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  374. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  375. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  376. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  377. sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
  378. }
  379. }
  380. }
  381. }
  382. }
  383. barrier();
  384. }
  385. const uint dr = ir * BM + warp_r * WM;
  386. const uint dc = ic * BN + warp_c * WN;
  387. #ifndef MUL_MAT_ID
  388. const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
  389. #endif
  390. [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
  391. [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
  392. const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
  393. const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
  394. [[unroll]] for (uint cc = 0; cc < TN; cc++) {
  395. #ifdef MUL_MAT_ID
  396. const uint row_i = dc_warp + cc;
  397. if (row_i >= _ne1) break;
  398. const u16vec2 row_idx = row_ids[row_i];
  399. #endif
  400. [[unroll]] for (uint cr = 0; cr < TM; cr++) {
  401. #ifdef MUL_MAT_ID
  402. data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
  403. #else
  404. if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
  405. data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
  406. }
  407. #endif
  408. }
  409. }
  410. }
  411. }
  412. }