@@ -19,6 +19,9 @@
 
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
+#define IS_MUL_MM2 1
+
+layout (constant_id = 0) const uint BLOCK_SIZE = 256;
layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant
@@ -70,6 +73,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#define DECODEFUNCA
#endif
 
+#if !defined(fetch_scales)
+#define fetch_scales(a, b, c, d, e, f)
+#endif
+#if !defined(store_scales)
+#define store_scales(a)
+#endif
+
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -116,6 +126,8 @@ void main() {
init_iq_shmem(gl_WorkGroupSize);
#endif
 
+ const uint tid = gl_LocalInvocationIndex;
+
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.z;
#else
@@ -218,14 +230,21 @@ void main() {
tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
#if !defined(MUL_MAT_ID)
+
+ const uint START_ALIGN_K = 256;
+ // For Qi_K (block size 256), unroll whole 256 element tiles.
+ // For legacy quants (block size 32), unroll 8x.
+ const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8);
+ const uint unroll_count = UNROLL_K / BK;
+
// Detect a fast path where all loads are entirely in bounds and no clamping is required
- if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % BK) == 0 && (end_k % BK) == 0 &&
+ if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 &&
#if QUANT_K == 1
(stride_a % 8) == 0 &&
#endif
- (stride_b % 8) == 0 && (start_k % 8) == 0) {
+ (stride_b % 8) == 0) {
// Hint to the compiler that values are aligned (want 16B alignment)
- start_k &= ~7;
+ start_k &= ~(START_ALIGN_K-1);
stride_b &= ~7;
#if QUANT_K == 1
stride_a &= ~7;
@@ -234,11 +253,39 @@ void main() {
tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1);
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);
 
- uint k_iters = (end_k - start_k + BK - 1) / BK;
+ uint k_iters = (end_k - start_k) / UNROLL_K;
+ uint block_k = start_k;
+
+ // fetch scale values for a tile of quants. These will be copied into shared memory.
+ // The fetches and stores are pipelined to hide the latency.
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true);
+
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
 
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
 
@@ -246,6 +293,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);
 
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
 
@@ -253,8 +301,30 @@ void main() {
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
 
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
 
@@ -262,6 +332,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);
 
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
 
@@ -269,8 +340,31 @@ void main() {
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
- for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
+ for (uint i = 0; i < k_iters; ++i) {
+
+ store_scales(tid);
+ if (block_k + UNROLL_K < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true);
+ }
+
+ // Manually partial unroll
+ [[unroll]] for (uint j = 0; j < unroll_count; ++j) {
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+ coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
+
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
+ }
+ }
+ // Do any remaining iterations that were not unrolled
+ if (block_k < end_k) {
+ store_scales(tid);
+ }
+ while (block_k < end_k) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
@@ -278,6 +372,7 @@ void main() {
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
sum = coopMatMulAdd(mat_a, mat_b, sum);
+ block_k += BK;
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
@@ -298,47 +393,29 @@ void main() {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
 
+ uint k_iters = (end_k - start_k + BK - 1) / BK;
+
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
+
[[dont_unroll]]
- for (uint block_k = start_k; block_k < end_k; block_k += BK) {
+ for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
+
+ store_scales(tid);
+ if (block_k + BK < end_k) {
+ fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
+ }
 
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
- // Clamping is expensive, so detect different code paths for each combination
- // of A and B needing clamping.
- bool unclampedA = (ir + 1) * BM <= p.M && block_k + BK <= end_k && (block_k % 8) == 0;
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
#ifdef MUL_MAT_ID
- bool unclampedB = true;
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
#else
- bool unclampedB = (ic + 1) * BN <= p.padded_N && block_k + BK <= end_k && (block_k % 8) == 0;
+ coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
#endif
- if (unclampedA && unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (unclampedA && !unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
-
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (!unclampedA && unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
-#ifdef MUL_MAT_ID
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB);
-#else
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
-#endif
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- } else if (!unclampedA && !unclampedB) {
- coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
- coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
- sum = coopMatMulAdd(mat_a, mat_b, sum);
- }
+ sum = coopMatMulAdd(mat_a, mat_b, sum);
}
 
// Convert from ACC_TYPE to D_TYPE