- #version 450
- #extension GL_EXT_control_flow_attributes : enable
- #extension GL_EXT_shader_16bit_storage : require
- #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
- #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
- #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
- #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
- #extension GL_KHR_memory_scope_semantics : enable
- #extension GL_KHR_cooperative_matrix : enable
- #extension GL_NV_cooperative_matrix2 : enable
- #extension GL_EXT_buffer_reference : enable
- #extension GL_KHR_shader_subgroup_ballot : enable
- #extension GL_KHR_shader_subgroup_vote : enable
- #ifdef DATA_A_BF16
- #extension GL_EXT_bfloat16 : enable
- #endif
- #include "types.comp"
- #include "utils.comp"
- layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
- #define IS_MUL_MM2 1
- layout (constant_id = 0) const uint BLOCK_SIZE = 256;
- layout (constant_id = 1) const uint BM = 64;
- layout (constant_id = 2) const uint BN = 64;
- layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
- layout (constant_id = 4) const bool enable_smaller_matrices = false;
- const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
- const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;
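- // When enable_smaller_matrices is set, BNover2/BNover4 give the fast path below
- // narrower accumulator widths (BN/2 or BN/4 columns) to use whenever this
- // workgroup's tile extends past the real N, reducing work for narrow matrices.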
- layout (push_constant) uniform parameter
- {
- uint M;
- uint N;
- uint K;
- uint stride_a;
- uint stride_b;
- uint stride_d;
- uint batch_stride_a;
- uint batch_stride_b;
- uint batch_stride_d;
- #ifdef MUL_MAT_ID
- uint nei0;
- uint nei1;
- uint nbi1;
- uint ne11;
- #else
- uint k_split;
- uint ne02;
- uint ne12;
- uint broadcast2;
- uint broadcast3;
- #endif
- // N dimension for the B matrix can be >= p.N
- uint padded_N;
- } p;
- layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
- layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
- layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
- #if QUANT_K > 1
- #define DECODEFUNCA , dequantFuncA
- #include "dequant_funcs_cm2.comp"
- #else
- #define DECODEFUNCA
- #endif
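- // For quantized A (QUANT_K > 1), dequantFuncA from dequant_funcs_cm2.comp is
- // appended to the coopMatLoadTensorNV calls so blocks are dequantized as they
- // are loaded; for non-quantized types the argument expands to nothing.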
- #if !defined(fetch_scales)
- #define fetch_scales(a, b, c, d, e, f)
- #endif
- #if !defined(store_scales)
- #define store_scales(a)
- #endif
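- // fetch_scales/store_scales default to no-ops; quant types that stage scale
- // data through shared memory presumably get real definitions from the dequant
- // includes.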
- #if defined(DATA_A_BF16)
- #define MAT_TYPE bfloat16_t
- #else
- #define MAT_TYPE FLOAT_TYPE
- #endif
- #ifdef MUL_MAT_ID
- layout (binding = 3) readonly buffer IDS {int data_ids[];};
- shared u16vec4 row_ids[4096];
- layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
- B_TYPE b[];
- };
- uint _ne1;
- layout (constant_id = 5) const uint subgroup_size = 32;
- shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size];
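- // Decode callback for loading B in the MUL_MAT_ID path: rows are gathered
- // indirectly through the row_ids table, and rows at or beyond _ne1 read as zero.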
- B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2])
- {
- const uint row_i = blockCoords[0];
- if (row_i >= _ne1) {
- return B_TYPE(0.0);
- }
- const u16vec4 row_idx = row_ids[row_i];
- B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
- return ret;
- }
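- // Per-element store callback used with coopMatPerElementNV: each element's
- // destination is remapped through row_ids before writing to D, and
- // out-of-range elements are skipped.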
- D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic)
- {
- uint dr = ir * BM + r;
- uint dc = ic * BN + c;
- if (dr < p.M && dc < _ne1) {
- uint row_i = dc;
- const u16vec4 row_idx = row_ids[row_i];
- data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
- }
- return elem;
- }
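- // Compact, in order, every (ii0, ii1) position of the nei0 x nei1 id matrix
- // whose expert id matches expert_idx into row_ids, using a subgroup ballot and
- // exclusive bit count as a workgroup-wide prefix sum; _ne1 ends up as the
- // number of matching rows. Ids are prefetched in batches of 16 outer
- // iterations to hide the latency of reading data_ids.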
- void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
- _ne1 = 0;
- uint num_elements = p.nei1 * p.nei0;
- uint nei0shift = findLSB(p.nei0);
- uint ids[16];
- uint iter = 0;
- for (uint j = 0; j < num_elements; j += BLOCK_SIZE) {
- // Prefetch up to 16 id values per invocation (one per BLOCK_SIZE stride),
- // covering the next 16 outer iterations
- if (iter == 0) {
- [[unroll]] for (uint k = 0; k < 16; ++k) {
- uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0;
- }
- }
- uint i = j + gl_LocalInvocationIndex;
- bool in_range = i < num_elements;
- uint ii1;
- if (nei0_is_pow2) {
- ii1 = i >> nei0shift;
- } else {
- ii1 = i / p.nei0;
- }
- uint ii0 = i - ii1 * p.nei0;
- uint id = ids[iter++];
- uvec4 ballot = subgroupBallot(in_range && id == expert_idx);
- ballots_sh[gl_SubgroupID] = ballot;
- barrier();
- uint subgroup_base = 0;
- uint total = 0;
- for (uint k = 0; k < gl_NumSubgroups; ++k) {
- if (k == gl_SubgroupID) {
- subgroup_base = total;
- }
- total += subgroupBallotBitCount(ballots_sh[k]);
- }
- barrier();
- uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
- if (in_range && id == expert_idx) {
- row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
- }
- _ne1 += total;
- iter &= 15;
- }
- barrier();
- }
- #endif
- void main() {
- #ifdef NEEDS_INIT_IQ_SHMEM
- init_iq_shmem(gl_WorkGroupSize);
- #endif
- const uint tid = gl_LocalInvocationIndex;
- #ifdef MUL_MAT_ID
- const uint expert_idx = gl_GlobalInvocationID.z;
- #else
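- // Decompose the z workgroup index into B/D batch indices (i12, i13) and map
- // them to A's batch indices (i02, i03) via the broadcast factors, implementing
- // broadcasting of A over dims 2 and 3.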
- const uint batch_idx = gl_GlobalInvocationID.z;
- const uint i13 = batch_idx / p.ne12;
- const uint i12 = batch_idx % p.ne12;
- const uint i03 = i13 / p.broadcast3;
- const uint i02 = i12 / p.broadcast2;
- const uint batch_idx_a = i03 * p.ne02 + i02;
- #endif
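- // gl_WorkGroupID.x packs both the M-tile index (ir) and the split-K slice (ik);
- // gl_WorkGroupID.y selects the N tile (ic).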
- const uint blocks_m = (p.M + BM - 1) / BM;
- const uint ir = gl_WorkGroupID.x % blocks_m;
- const uint ik = gl_WorkGroupID.x / blocks_m;
- const uint ic = gl_WorkGroupID.y;
- #ifdef MUL_MAT_ID
- if (bitCount(p.nei0) == 1) {
- load_row_ids(expert_idx, true);
- } else {
- load_row_ids(expert_idx, false);
- }
- // Workgroup has no work
- if (ic * BN >= _ne1) return;
- #endif
- #ifdef MUL_MAT_ID
- uint start_k = 0;
- const uint end_k = p.K;
- #else
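- // With split-K, slice ik covers [ik * k_split, (ik + 1) * k_split) of the K
- // dimension; each slice writes its partial sums to a separate region of D
- // (see pos_d below), presumably combined by a later reduction pass.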
- uint start_k = ik * p.k_split;
- const uint end_k = min(p.K, (ik + 1) * p.k_split);
- #endif
- #ifdef MUL_MAT_ID
- uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
- uint pos_b = 0;
- #else
- uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
- uint pos_b = batch_idx * p.batch_stride_b;
- uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
- #endif
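- // Quantized A is addressed in units of blocks, so A offsets and strides are
- // divided by QUANT_K.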
- uint stride_a = p.stride_a / QUANT_K;
- uint stride_b = p.stride_b;
- // Hint to the compiler that the strides are aligned (we want 16B-aligned accesses).
- // Quantized A data is always block-aligned, so only the non-quant strides need the hint.
- #if ALIGNED
- #if QUANT_K == 1
- stride_a &= ~7;
- #endif
- stride_b &= ~7;
- #endif
- // Create layouts for both clamped and unclamped accesses
- tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2);
- tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
- tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
- tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);
- #if QUANT_K > 1
- tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
- tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K);
- #endif
- // Use end_k rather than p.K as the dimension because that's what
- // we need to bound check against when using split_k.
- // Bounds check B against padded_N, but bounds check D against N.
- tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k);
- tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k);
- tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M);
- tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k);
- tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k);
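- // tensorViewTranspose swaps the two tensor dimensions: B is stored N-major
- // (padded_N x K) but consumed as a K x BN operand, and the same view stores the
- // BM x BN result into D, which is laid out as N x M.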
- tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
- #if !defined(MUL_MAT_ID)
- const uint START_ALIGN_K = 256;
- // For Qi_K (block size 256), unroll whole 256 element tiles.
- // For legacy quants (block size 32), unroll 8x.
- const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
- const uint unroll_count = UNROLL_K / BK;
- // Detect a fast path where all loads are entirely in bounds and no clamping is required
- if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
- #if QUANT_K == 1
- (stride_a % 8) == 0 &&
- #endif
- (stride_b % 8) == 0) {
- // Hint to the compiler that values are aligned (want 16B alignment)
- start_k &= ~(START_ALIGN_K-1);
- stride_b &= ~7;
- #if QUANT_K == 1
- stride_a &= ~7;
- #endif
- tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
- tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
- uint k_iters = (end_k - start_k) / UNROLL_K;
- uint block_k = start_k;
- // Fetch the scale values for a tile of quants; they are staged through shared memory.
- // The fetches and stores are pipelined to hide the latency.
- fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
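- // Pick the narrowest accumulator (BN/4, BN/2, or the full BN columns) that
- // still covers this workgroup's share of N.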
- if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
- for (uint i = 0; i < k_iters; ++i) {
- store_scales(tid);
- if (block_k + UNROLL_K < end_k) {
- fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
- }
- // Manual partial unroll
- [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- }
- // Do any remaining iterations that were not unrolled
- if (block_k < end_k) {
- store_scales(tid);
- }
- while (block_k < end_k) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
- coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
- return;
- } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
- for (uint i = 0; i < k_iters; ++i) {
- store_scales(tid);
- if (block_k + UNROLL_K < end_k) {
- fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
- }
- // Manual partial unroll
- [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- }
- // Do any remaining iterations that were not unrolled
- if (block_k < end_k) {
- store_scales(tid);
- }
- while (block_k < end_k) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
- coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
- return;
- } else {
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
- for (uint i = 0; i < k_iters; ++i) {
- store_scales(tid);
- if (block_k + UNROLL_K < end_k) {
- fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
- }
- // Manual partial unroll
- [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- }
- // Do any remaining iterations that were not unrolled
- if (block_k < end_k) {
- store_scales(tid);
- }
- while (block_k < end_k) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- block_k += BK;
- }
- coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
- coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
- return;
- }
- } else
- #endif // !defined(MUL_MAT_ID)
- {
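- // General path: the tile may run past M, padded_N, or end_k and the strides may
- // be unaligned, so per-iteration bounds checks select the clamped layouts where
- // needed.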
- tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
- tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1);
- tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
- tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
- sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
- uint k_iters = (end_k - start_k + BK - 1) / BK;
- fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
- [[dont_unroll]]
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
- store_scales(tid);
- if (block_k + BK < end_k) {
- fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
- }
- if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
- #ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
- #else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
- #endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else {
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
- coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
- #ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
- #else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
- #endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- }
- }
- // Convert from ACC_TYPE to D_TYPE
- coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
- mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
- #ifdef MUL_MAT_ID
- // Store each element through the perElemOpD callback, remapping each element's
- // destination through the row_ids table in shared memory
- coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
- #else
- coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
- #endif
- }
- }