| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469 |
#version 450

// Tiled matrix-multiplication compute shader. Optional features are selected
// at build time with preprocessor defines (FLOAT16, COOPMAT, MUL_MAT_ID,
// MUL_MAT_ID_USE_SUBGROUPS, DATA_A_*), and each define pulls in the
// extensions it needs below.

// [[unroll]] loop attributes are used throughout main().
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#ifdef FLOAT16
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#endif

#if defined(DATA_A_IQ1_M)
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#if defined(DATA_A_BF16) && defined(COOPMAT)
#extension GL_EXT_bfloat16 : enable
#endif

// Cooperative-matrix back end (coopmat<>, coopMatLoad/Store/MulAdd, gl_Scope*).
#ifdef COOPMAT
#extension GL_KHR_cooperative_matrix : enable
#extension GL_KHR_memory_scope_semantics : enable
#endif

// Subgroup built-ins: gl_SubgroupID/gl_SubgroupInvocationID in the COOPMAT
// path, and subgroupBallot* in load_row_ids().
#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS)
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#endif

// MUL_MAT_ID keeps its row table as u16vec2, which needs 16-bit integer types.
#ifdef MUL_MAT_ID
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#endif

#include "types.comp"

// Vector width of the A and B loads. pos_a / pos_b in main() are expressed in
// units of this width (they are divided by LOAD_VEC_A / LOAD_VEC_B).
#ifndef LOAD_VEC_A
#define LOAD_VEC_A 1
#endif
#ifndef LOAD_VEC_B
#define LOAD_VEC_B 1
#endif

// Not referenced in this file; presumably used by the load helpers in
// mul_mm_funcs.comp (TODO confirm). Defaults to the shared-memory element type.
#if !defined(TO_FLOAT_TYPE)
#define TO_FLOAT_TYPE FLOAT_TYPE
#endif
// One-dimensional workgroup; its x size is specialization constant 0 (BLOCK_SIZE).
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Matrix A. When the A type defines packed 16-/32-bit views, the same binding
// is also aliased through them (presumably consumed by the load helpers in
// mul_mm_funcs.comp).
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif

// Matrix B (read) and result matrix D (written).
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};

#ifdef MUL_MAT_ID
// Id table, read as data_ids[ii1 * nbi1 + ii0]. Entries equal to the current
// expert index (workgroup z) select which (ii0, ii1) pairs become output columns.
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif

layout (push_constant) uniform parameter
{
    uint M;               // rows of the output; store/bounds checks compare against it
    uint N;               // columns of the output (checked in the non-MUL_MAT_ID store)
    uint K;               // reduction length
    uint stride_a;        // stride between consecutive rows of A (pos_a uses ir * BM * stride_a)
    uint stride_b;        // stride between consecutive rows of B (pos_b uses ic * BN * stride_b)
    uint stride_d;        // stride between consecutive output columns in D
    uint batch_stride_a;
    uint batch_stride_b;
    uint batch_stride_d;

#ifdef MUL_MAT_ID
    uint nei0;            // inner extent of the id table
    uint nei1;            // outer extent of the id table
    uint nbi1;            // stride between id-table rows
    uint ne11;            // not referenced in this file; presumably used by load_b_to_shmem
#else
    uint k_split;         // K elements per split; workgroup split ik covers [ik*k_split, (ik+1)*k_split)
    uint ne02;            // A's dim-2 extent, used to flatten (i03, i02) into batch_idx_a
    uint ne12;            // B's dim-2 extent, used to split batch_idx into (i13, i12)
    uint broadcast2;      // divisor mapping B's dim-2 index onto A's (i02 = i12 / broadcast2)
    uint broadcast3;      // divisor mapping B's dim-3 index onto A's (i03 = i13 / broadcast3)
#endif
} p;

// Tiling parameters, set as specialization constants by the host.
layout (constant_id = 0) const uint BLOCK_SIZE = 64;   // invocations per workgroup
layout (constant_id = 1) const uint BM = 64;           // output tile rows per workgroup
layout (constant_id = 2) const uint BN = 64;           // output tile columns per workgroup
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
layout (constant_id = 4) const uint WM = 32;           // rows covered by one warp/subgroup
layout (constant_id = 5) const uint WN = 32;           // columns covered by one warp/subgroup
layout (constant_id = 6) const uint WMITER = 2;        // warp sub-tiles along M (scalar path)
layout (constant_id = 7) const uint TM = 4;            // per-thread rows (scalar) / coopmat rows (COOPMAT)
layout (constant_id = 8) const uint TN = 2;            // per-thread cols (scalar) / coopmat cols (COOPMAT)
layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat
layout (constant_id = 10) const uint WARP = 32;        // invocations per warp/subgroup

// Row stride (in elements) of the shared A/B tiles. The padding beyond BK is
// presumably there to reduce bank conflicts (scalar path) and to meet
// coopMatLoad alignment needs (COOPMAT) — TODO confirm.
#ifdef COOPMAT
#define SHMEM_STRIDE (BK + 8)
#else
#define SHMEM_STRIDE (BK + 1)
#endif

// Shared K-slices of the current A tile (BM rows) and B tile (BN rows).
shared FLOAT_TYPE buf_a[BM * SHMEM_STRIDE];
shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];

#define NUM_WARPS (BLOCK_SIZE / WARP)

#ifdef MUL_MAT_ID
// (ii0, ii1) for each output column of this workgroup's BN-wide tile.
shared u16vec2 row_ids[BN];
// Number of id entries matching the expert that have been seen so far
// (a private, per-invocation copy; the code keeps it uniform across the workgroup).
uint _ne1;

#ifdef MUL_MAT_ID_USE_SUBGROUPS
// One ballot per subgroup, used to compute workgroup-wide prefix counts.
// NOTE(review): indexed by gl_SubgroupID, so this assumes gl_NumSubgroups <=
// NUM_WARPS, i.e. the device subgroup size equals WARP — confirm the host
// enforces this.
shared uvec4 ballots_sh[NUM_WARPS];
// Splits a flat position in the nei1 x nei0 id table into uvec2(ii0, ii1).
// When nei0_is_pow2 is set, nei0_log2 must be log2(p.nei0) and a shift
// replaces the division.
uvec2 mmid_split_flat_index(uint flat_idx, bool nei0_is_pow2, uint nei0_log2) {
    const uint row = nei0_is_pow2 ? (flat_idx >> nei0_log2) : (flat_idx / p.nei0);
    return uvec2(flat_idx - row * p.nei0, row);
}

// Fills the shared row_ids table for the column tile [ic * BN, (ic + 1) * BN).
//
// The id table is scanned in flat order, BLOCK_SIZE entries per step. Every
// entry equal to expert_idx gets a rank (its position among all matches). A
// subgroup ballot plus a prefix sum over the per-subgroup ballots in
// ballots_sh assigns these ranks. Matches whose rank lies inside the tile
// write their (ii0, ii1) to row_ids[rank - ic * BN].
//
// Each invocation prefetches its next 16 id values at once, so data_ids is
// read in batches. _ne1 ends up holding the number of matches counted, and
// is identical in every invocation. The scan stops early once the tile is
// full. A final barrier() makes row_ids visible to the whole workgroup.
void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
    _ne1 = 0;

    const uint PREFETCH_DEPTH = 16u;
    const uint tile_begin = ic * BN;
    const uint tile_end = tile_begin + BN;
    const uint total_ids = p.nei1 * p.nei0;
    const uint nei0_log2 = uint(findLSB(p.nei0));

    uint prefetched[PREFETCH_DEPTH];
    uint slot = 0;

    for (uint base = 0; base < total_ids; base += BLOCK_SIZE) {
        // Refill the private window with this invocation's next PREFETCH_DEPTH ids.
        if (slot == 0) {
            [[unroll]] for (uint ahead = 0; ahead < PREFETCH_DEPTH; ++ahead) {
                const uint ahead_idx = base + gl_LocalInvocationIndex + ahead * BLOCK_SIZE;
                const uvec2 ahead_pos = mmid_split_flat_index(ahead_idx, nei0_is_pow2, nei0_log2);
                // The ternary only evaluates data_ids[] when the index is in range.
                prefetched[ahead] = ahead_idx < total_ids ? data_ids[ahead_pos.y * p.nbi1 + ahead_pos.x] : 0;
            }
        }

        const uint cur_idx = base + gl_LocalInvocationIndex;
        const uvec2 cur_pos = mmid_split_flat_index(cur_idx, nei0_is_pow2, nei0_log2);
        const bool hit = cur_idx < total_ids && prefetched[slot] == expert_idx;
        slot = (slot + 1u) & (PREFETCH_DEPTH - 1u);

        // Publish this subgroup's matches, then prefix-sum the match counts over subgroups.
        const uvec4 hit_mask = subgroupBallot(hit);
        ballots_sh[gl_SubgroupID] = hit_mask;
        barrier();

        uint before_me = 0;
        uint round_hits = 0;
        for (uint sg = 0; sg < gl_NumSubgroups; ++sg) {
            if (sg == gl_SubgroupID) {
                before_me = round_hits;
            }
            round_hits += subgroupBallotBitCount(ballots_sh[sg]);
        }
        // Every invocation has finished reading ballots_sh before the next step overwrites it.
        barrier();

        const uint rank = _ne1 + before_me + subgroupBallotExclusiveBitCount(hit_mask);
        if (hit && rank >= tile_begin && rank < tile_end) {
            row_ids[rank - tile_begin] = u16vec2(cur_pos.x, cur_pos.y);
        }

        _ne1 += round_hits;
        // _ne1 is workgroup-uniform, so this exit is uniform as well.
        if (_ne1 >= tile_end) {
            break;
        }
    }
    barrier();
}
#endif // MUL_MAT_ID_USE_SUBGROUPS
#endif // MUL_MAT_ID

#ifdef COOPMAT
// Per-subgroup TM x TN staging area. main() stores a cooperative-matrix
// accumulator here (column-major, offset warp_i * TM * TN) when it cannot
// coopMatStore directly into data_d, then scatters it element by element.
shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif

// Not shown here: presumably defines load_a_to_shmem / load_b_to_shmem (and
// init_iq_shmem for NEEDS_INIT_IQ_SHMEM variants), which main() calls.
#include "mul_mm_funcs.comp"
// Tiled matrix multiplication. Each workgroup produces one BM x BN tile of
// the output, where (ignoring batch offsets)
//   data_d[n * stride_d + m] = sum_k A(row m, element k) * B(row n, element k),
// with A rows starting at m * stride_a and B rows at n * stride_b.
//
// Workgroup ids:
//   x -> ir (M-tile, x % blocks_m) and ik (K-split index, x / blocks_m)
//   y -> ic (N-tile)
//   z -> batch index, or expert index when MUL_MAT_ID is defined
//
// Per BK-wide K slice, tiles of A and B are staged in shared memory
// (buf_a / buf_b) and then multiplied either with cooperative matrices
// (COOPMAT) or with a per-thread scalar register tile.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

#ifdef MUL_MAT_ID
    const uint expert_idx = gl_GlobalInvocationID.z;
#else
    // Split the batch index into B's (i13, i12) and map it onto A's batch
    // (i03, i02) through the broadcast divisors.
    const uint batch_idx = gl_GlobalInvocationID.z;
    const uint i13 = batch_idx / p.ne12;
    const uint i12 = batch_idx % p.ne12;
    const uint i03 = i13 / p.broadcast3;
    const uint i02 = i12 / p.broadcast2;
    const uint batch_idx_a = i03 * p.ne02 + i02;
#endif

    const uint blocks_m = (p.M + BM - 1) / BM;
    const uint ir = gl_WorkGroupID.x % blocks_m;
    const uint ik = gl_WorkGroupID.x / blocks_m;
    const uint ic = gl_WorkGroupID.y;

    // Scalar-path sub-tiling of a warp's WM x WN region: WMITER x WNITER
    // sub-tiles of WSUBM x WSUBN each.
    const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
    const uint WSUBM = WM / WMITER;
    const uint WSUBN = WN / WNITER;

#ifdef COOPMAT
    const uint warp_i = gl_SubgroupID;
    const uint tiw = gl_SubgroupInvocationID;
    // A subgroup's WM x WN region is covered by cms_per_row x cms_per_col
    // cooperative matrices of TM x TN each.
    const uint cms_per_row = WM / TM;
    const uint cms_per_col = WN / TN;
    // Lane mapping for scattering a staged TM x TN matrix: each lane handles
    // row store_r and columns store_c, store_c + storestride, ...
    const uint storestride = WARP / TM;
    const uint store_r = tiw % TM;
    const uint store_c = tiw / TM;
#else
    const uint warp_i = gl_LocalInvocationID.x / WARP;
    const uint tiw = gl_LocalInvocationID.x % WARP;
    // Thread position inside a WSUBM x WSUBN sub-tile, in units of TM rows / TN columns.
    const uint tiwr = tiw % (WSUBM / TM);
    const uint tiwc = tiw / (WSUBM / TM);
#endif

    // Position of this warp's WM x WN region inside the BM x BN tile.
    const uint warp_r = warp_i % (BM / WM);
    const uint warp_c = warp_i / (BM / WM);

    // Shared-memory load mapping: loadr_* is the thread's vector index along
    // K, loadc_* its tile row, and loadstride_* the number of rows covered per pass.
    const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A);
    const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A);
    const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B);
    const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B);
    const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A / BK;
    const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B / BK;

#ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS
    // The power-of-two flag is passed as a literal, presumably so each call
    // site can be specialized by the compiler.
    if (bitCount(p.nei0) == 1) {
        load_row_ids(expert_idx, true, ic);
    } else {
        load_row_ids(expert_idx, false, ic);
    }
#else
    // Serial scan. Every invocation walks the id table in flat order, counts
    // matches in _ne1, and records those whose rank falls in this workgroup's
    // column tile [ic * BN, (ic + 1) * BN). All invocations write identical
    // values to row_ids.
    _ne1 = 0;
    for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
        for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
            if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
                if (_ne1 >= ic * BN) {
                    row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
                }
                _ne1++;
            }
        }
    }

    barrier();
#endif

    // Workgroup has no work.
    // _ne1 is the same in every invocation here, so the return is uniform.
    if (ic * BN >= _ne1) return;
#endif

#ifdef MUL_MAT_ID
    const uint start_k = 0;
    const uint end_k = p.K;
#else
    // Split-K: this workgroup only reduces over [start_k, end_k).
    const uint start_k = ik * p.k_split;
    const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

    // Starting read positions in units of LOAD_VEC_A / LOAD_VEC_B elements.
    uint pos_a = (
#ifdef MUL_MAT_ID
        expert_idx * p.batch_stride_a +
#else
        batch_idx_a * p.batch_stride_a +
#endif
        ir * BM * p.stride_a + start_k) / LOAD_VEC_A;
#ifdef MUL_MAT_ID
    // B rows are looked up through row_ids by load_b_to_shmem (it receives ic and _ne1).
    uint pos_b = 0;
#else
    uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
#endif

#ifdef COOPMAT
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a;
    coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b;
    // Accumulator for matrix (cm_row, cm_col) lives at sums[cm_col * cms_per_row + cm_row].
    coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col];

    [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) {
        sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
    }
#else
    // Per-thread register tile, indexed (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr.
    ACC_TYPE sums[WMITER * TM * WNITER * TN];
    FLOAT_TYPE cache_a[WMITER * TM];
    FLOAT_TYPE cache_b[TN];

    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = ACC_TYPE(0.0f);
    }
#endif

    for (uint block = start_k; block < end_k; block += BK) {
        // Stage the current BK-wide K slice of the A and B tiles in shared memory.
        [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) {
            load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block + loadr_a, end_k);
        }
        [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if !defined(MUL_MAT_ID)
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block + loadr_b, end_k);
#else
            load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block + loadr_b, end_k);
#endif
        }

        // Tiles must be complete before any invocation reads them.
        barrier();

        pos_a += BK / LOAD_VEC_A;
        pos_b += BK / LOAD_VEC_B;

#ifdef COOPMAT
        [[unroll]] for (uint i = 0; i < BK; i += TK) {
            [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
                // Load from shared into cache
                coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor);

                [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
                    coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor);

                    sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]);
                }
            }
        }
#else
        [[unroll]] for (uint i = 0; i < BK; i++) {
            // Load from shared into cache
            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                [[unroll]] for (uint j = 0; j < TM; j++) {
                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
                }
            }
            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
                [[unroll]] for (uint j = 0; j < TN; j++) {
                    cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
                }

                // Outer-product update of the TM x TN register tile.
                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
                            sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr]), ACC_TYPE(cache_b[cc]), sums[sums_idx]);
                        }
                    }
                }
            }
        }
#endif

        // All reads of buf_a/buf_b must finish before the next slice overwrites them.
        barrier();
    }

#if defined(ACC_TYPE_MAX)
    // Clamp accumulators to +/-ACC_TYPE_MAX before storing (presumably to
    // avoid producing inf with a narrow accumulator type — TODO confirm).
#ifdef COOPMAT
    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
        }
    }
#else
    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
    }
#endif
#endif

    // Global row (dr) and column (dc) origin of this warp's output region.
    const uint dr = ir * BM + warp_r * WM;
    const uint dc = ic * BN + warp_c * WN;

#ifndef MUL_MAT_ID
    // Each K split ik writes its own copy of D, placed after all batches
    // (presumably reduced by a separate pass — TODO confirm).
    const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

#ifdef COOPMAT
#ifdef MUL_MAT_ID
    // Output columns are scattered through row_ids, so every matrix is staged
    // in coopmat_stage (column-major, TM x TN) and written element by element.
    // NOTE(review): other lanes read coopmat_stage right after coopMatStore
    // with no subgroup-scope barrier; confirm this is safe on targets without
    // lockstep subgroup execution.
    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

            [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                const uint row_i = dc + cm_col * TN + col + store_c;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];

                if (dr + cm_row * TM + store_r < p.M) {
                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            }
        }
    }
#else
    const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float

    [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) {
        [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) {
            const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N;

            if (is_aligned && is_in_bounds) {
                // Full coopMat is within bounds and stride_d is aligned with 16B
                coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]);
                coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
            } else if (is_in_bounds) {
                // Full coopMat is within bounds, but stride_d is not aligned
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                }
            } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) {
                // Partial coopMat is within bounds
                coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor);

                [[unroll]] for (uint col = 0; col < TN; col += storestride) {
                    if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) {
                        data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
                    }
                }
            }
            // Matrices entirely outside [0, M) x [0, N) are skipped.
        }
    }
#endif // MUL_MAT_ID
#else
    // Scalar path: each thread writes its own register tile, bounds-checked per element.
    [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {

            const uint dr_warp = dr + wsir * WSUBM + tiwr * TM;
            const uint dc_warp = dc + wsic * WSUBN + tiwc * TN;
            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
#ifdef MUL_MAT_ID
                // Column index into the matched-id list; stop past the last match.
                const uint row_i = dc_warp + cc;
                if (row_i >= _ne1) break;

                const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID
                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID
                    if (dr_warp + cr < p.M) {
                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
#else
                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                        data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
                    }
#endif // MUL_MAT_ID
                }
            }
        }
    }
#endif // COOPMAT
}
|