// mul_mm_cm2.comp
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
#extension GL_NV_cooperative_matrix2 : enable
#extension GL_EXT_buffer_reference : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#extension GL_KHR_shader_subgroup_vote : enable

#ifdef DATA_A_BF16
#extension GL_EXT_bfloat16 : enable
#endif

#include "types.comp"
#include "utils.comp"

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

#define IS_MUL_MM2 1

layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16;  // Assumed to be 32 if working with a quant
layout (constant_id = 4) const bool enable_smaller_matrices = false;

const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;

layout (push_constant) uniform parameter
{
    uint M;
    uint N;
    uint K;
    uint stride_a;
    uint stride_b;
    uint stride_d;

    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;
    uint nei1;
    uint nbi1;
    uint ne11;
#else
    uint k_split;
    uint ne02;
    uint ne12;
    uint broadcast2;
    uint broadcast3;
#endif
    // N dimension for the B matrix can be >= p.N
    uint padded_N;
} p;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#if QUANT_K > 1
#define DECODEFUNCA , dequantFuncA
#include "dequant_funcs_cm2.comp"
#else
#define DECODEFUNCA
#endif

#if !defined(fetch_scales)
#define fetch_scales(a, b, c, d, e, f)
#endif

#if !defined(store_scales)
#define store_scales(a)
#endif

#if defined(DATA_A_BF16)
#define MAT_TYPE bfloat16_t
#else
#define MAT_TYPE FLOAT_TYPE
#endif

#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

shared u16vec4 row_ids[4096];

layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
    B_TYPE b[];
};

uint _ne1;

layout (constant_id = 5) const uint subgroup_size = 32;
shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
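
// Tensor-load callback for B in the MUL_MAT_ID path: remaps the logical row
// through the row_ids table built by load_row_ids, and returns zero for rows
// beyond _ne1.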
B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
    const uint row_i = blockCoords[0];

    if (row_i >= _ne1) {
        return B_TYPE(0.0);
    }

    const u16vec4 row_idx = row_ids[row_i];
    B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];

    return ret;
}
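
// Per-element store callback for the MUL_MAT_ID path: scatters each result
// element to its destination row via the row_ids table, bounds-checked
// against p.M and _ne1.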
D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
{
    uint dr = ir * BM + r;
    uint dc = ic * BN + c;

    if (dr < p.M && dc < _ne1) {
        uint row_i = dc;
        const u16vec4 row_idx = row_ids[row_i];

        data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
    }
    return elem;
}
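
// Builds the compacted row_ids table for one expert: every thread scans the
// nei0 x nei1 id matrix, and entries equal to expert_idx are packed densely
// using a subgroup-ballot prefix sum. Ids are prefetched 16 iterations ahead
// to hide the latency of the data_ids loads.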
void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
    _ne1 = 0;
    uint num_elements = p.nei1 * p.nei0;
    uint nei0shift = findLSB(p.nei0);

    uint ids[16];
    uint iter = 0;

    for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
        // prefetch up to 16 elements
        if (iter == 0) {
            [[unroll]] for (uint k = 0; k < 16; ++k) {
                uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
                bool in_range = i < num_elements;
                uint ii1;
                if (nei0_is_pow2) {
                    ii1 = i >> nei0shift;
                } else {
                    ii1 = i / p.nei0;
                }
                uint ii0 = i - ii1 * p.nei0;
                ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
            }
        }

        uint i = j + gl_LocalInvocationIndex;
        bool in_range = i < num_elements;
        uint ii1;
        if (nei0_is_pow2) {
            ii1 = i >> nei0shift;
        } else {
            ii1 = i / p.nei0;
        }
        uint ii0 = i - ii1 * p.nei0;
        uint id = ids[iter++];

        uvec4 ballot = subgroupBallot(in_range && id == expert_idx);

        ballots_sh[gl_SubgroupID] = ballot;
        barrier();

        uint subgroup_base = 0;
        uint total = 0;
        for (uint k = 0; k < gl_NumSubgroups; ++k) {
            if (k == gl_SubgroupID) {
                subgroup_base = total;
            }
            total += subgroupBallotBitCount(ballots_sh[k]);
        }
        barrier();

        uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
        if (in_range && id == expert_idx) {
            row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
        }
        _ne1 += total;
        iter &= 15;
    }
    barrier();
}
#endif
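
// Each workgroup accumulates one BM x BN output tile over (its slice of) the K
// dimension using cooperative-matrix multiply-adds. A is dequantized on load
// when QUANT_K > 1 (via DECODEFUNCA), and B is read through a transposed
// tensor view.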
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint tid = gl_LocalInvocationIndex;

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
#else
    const uint batch_idx = gl_GlobalInvocationID.z;

    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;

    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;

    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif
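
    // gl_WorkGroupID.x encodes both the M-tile index (ir) and the k-split
    // index (ik); gl_WorkGroupID.y is the N-tile index (ic).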
    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

#ifdef MUL_MAT_ID
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true);
    } else {
        load_row_ids(expert_idx, false);
    }

    // Workgroup has no work
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    uint start_k = 0;
    const uint end_k = p.K;
#else
    uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

#ifdef MUL_MAT_ID
    uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
    uint pos_b = 0;
#else
    uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
    uint pos_b = batch_idx * p.batch_stride_b;
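    // With split_k, each k-slice (ik) writes its partial result to its own
    // region of D, offset by ik * p.batch_stride_d * gl_NumWorkGroups.z.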
    uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

    uint stride_a = p.stride_a / QUANT_K;
    uint stride_b = p.stride_b;

    // Hint to the compiler that values are aligned (want 16B alignment).
    // Quants are always block-aligned, no alignment needed.
#if ALIGNED
#if QUANT_K == 1
    stride_a &= ~7;
#endif
    stride_b &= ~7;
#endif

    // Create layouts for both clamped and unclamped accesses
    tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);

    tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

#if QUANT_K > 1
    tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
    tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
#endif

    // Use end_k rather than p.K as the dimension because that's what
    // we need to bound check against when using split_k.
    // Bounds check B against padded_N, but bounds check D against N.
    tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
    tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
    tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
    tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
    tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
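
    // Tensor view that swaps the two dimensions: B tiles are loaded through it
    // as BK x BN matrices, and D tiles are stored back through it.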
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if !defined(MUL_MAT_ID)
    const uint START_ALIGN_K = 256;
    // For Qi_K (block size 256), unroll whole 256 element tiles.
    // For legacy quants (block size 32), unroll 8x.
    const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
    const uint unroll_count = UNROLL_K / BK;

    // Detect a fast path where all loads are entirely in bounds and no clamping is required
    if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
        (stride_a % 8) == 0 &&
#endif
        (stride_b % 8) == 0) {

        // Hint to the compiler that values are aligned (want 16B alignment)
        start_k &= ~(START_ALIGN_K-1);
        stride_b &= ~7;
#if QUANT_K == 1
        stride_a &= ~7;
#endif

        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

        uint k_iters = (end_k - start_k) / UNROLL_K;

        uint block_k = start_k;

        // fetch scale values for a tile of quants. These will be copied into shared memory.
        // The fetches and stores are pipelined to hide the latency.
        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
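
        // If the part of this BN-wide tile that lies within p.N fits in BN/4
        // (or BN/2) columns, use a narrower accumulator instead of the full BN
        // to cut down the per-workgroup work.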
        if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
            for (uint i = 0; i < k_iters; ++i) {
                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manually partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }

            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
            return;
        } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
            for (uint i = 0; i < k_iters; ++i) {
                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manually partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }

            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
            return;
        } else {
            coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
            for (uint i = 0; i < k_iters; ++i) {
                store_scales(tid);
                if (block_k + UNROLL_K < end_k) {
                    fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
                }

                // Manually partial unroll
                [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                    coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                    coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                    coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

                    sum = coopMatMulAdd(mat_a, mat_b, sum);
                    block_k += BK;
                }
            }

            // Do any remaining iterations that were not unrolled
            if (block_k < end_k) {
                store_scales(tid);
            }
            while (block_k < end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

                sum = coopMatMulAdd(mat_a, mat_b, sum);
                block_k += BK;
            }

            coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

            coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
            return;
        }
    } else
#endif // !defined(MUL_MAT_ID)
    {
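        // Generic path: used when the fast-path conditions do not hold, and
        // always for MUL_MAT_ID. Loads that may touch out-of-bounds rows go
        // through the clamped layouts (or through decodeFuncB for MUL_MAT_ID B).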
        tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
        tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
        tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
        tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
        sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

        uint k_iters = (end_k - start_k + BK - 1) / BK;

        fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);

        [[dont_unroll]]
        for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

            store_scales(tid);
            if (block_k + BK < end_k) {
                fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
            }

            if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            } else {
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
                coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

                coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
                coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif

                sum = coopMatMulAdd(mat_a, mat_b, sum);
            }
        }

        // Convert from ACC_TYPE to D_TYPE
        coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
        mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

#ifdef MUL_MAT_ID
        // Call callback to store each element, remapping row through shared memory
        coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
        coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
    }
}