|
|
@@ -107,11 +107,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
|
|
|
{
|
|
|
const uint row_i = blockCoords[0];
|
|
|
|
|
|
- if (row_i >= _ne1) {
|
|
|
- return B_TYPE(0.0);
|
|
|
- }
|
|
|
-
|
|
|
- const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
|
|
|
+ const u16vec4 row_idx = row_ids[row_i];
|
|
|
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
|
|
|
|
|
|
return ret;
|
|
|
@@ -194,12 +190,21 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
|
|
|
#endif
|
|
|
|
|
|
void main() {
|
|
|
+ const uint tid = gl_LocalInvocationIndex;
|
|
|
+#ifdef MUL_MAT_ID
|
|
|
+ // initialize to row 0 so we don't need to bounds check
|
|
|
+ if (tid < BN) {
|
|
|
+ row_ids[tid] = u16vec4(0);
|
|
|
+ }
|
|
|
+#if !defined(NEEDS_INIT_IQ_SHMEM)
|
|
|
+ barrier();
|
|
|
+#endif
|
|
|
+#endif
|
|
|
+
|
|
|
#ifdef NEEDS_INIT_IQ_SHMEM
|
|
|
init_iq_shmem(gl_WorkGroupSize);
|
|
|
#endif
|
|
|
|
|
|
- const uint tid = gl_LocalInvocationIndex;
|
|
|
-
|
|
|
#ifdef MUL_MAT_ID
|
|
|
const uint expert_idx = gl_GlobalInvocationID.z;
|
|
|
#else
|
|
|
@@ -482,7 +487,7 @@ void main() {
|
|
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
|
|
|
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
|
} else {
|
|
|
@@ -490,7 +495,7 @@ void main() {
|
|
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
|
|
|
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
|
}
|
|
|
@@ -526,7 +531,7 @@ void main() {
|
|
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
|
|
|
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
|
} else {
|
|
|
@@ -534,7 +539,7 @@ void main() {
|
|
|
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
|
|
|
sum = coopMatMulAdd(mat_a, mat_b, sum);
|
|
|
}
|
|
|
@@ -571,7 +576,7 @@ void main() {
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
#ifdef MUL_MAT_ID
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
#else
|
|
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
|
|
#endif
|
|
|
@@ -583,7 +588,7 @@ void main() {
|
|
|
|
|
|
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
|
|
|
#ifdef MUL_MAT_ID
|
|
|
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
|
|
|
#else
|
|
|
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
|
|
|
#endif
|