|
|
@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
|
|
return max(x, y);
|
|
|
}
|
|
|
|
|
|
+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
|
|
|
+ return max(x, y);
|
|
|
+}
|
|
|
+
|
|
|
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
|
|
return x;
|
|
|
}
|
|
|
@@ -142,21 +146,7 @@ void main() {
|
|
|
[[dont_unroll]]
|
|
|
for (uint32_t j = start_j; j < end_j; ++j) {
|
|
|
|
|
|
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
|
|
-
|
|
|
- coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
|
|
-
|
|
|
- uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
|
|
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
|
|
- S = coopMatMulAdd(Qf16, K_T, S);
|
|
|
-
|
|
|
- if (p.logit_softcap != 0.0f) {
|
|
|
- [[unroll]]
|
|
|
- for (int k = 0; k < S.length(); ++k) {
|
|
|
- S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
|
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
|
|
|
|
|
@@ -164,12 +154,17 @@ void main() {
|
|
|
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
|
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
|
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
+ tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
|
|
|
|
|
|
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
|
|
|
|
|
|
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
|
|
|
|
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
+ // skip the block if the mask is entirely -inf
|
|
|
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
|
|
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
} else {
|
|
|
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
|
// Don't clamp against nem1 when GQA is enabled
|
|
|
@@ -177,14 +172,37 @@ void main() {
|
|
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
|
|
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
|
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
|
|
|
|
|
|
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
|
|
|
|
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
+ // skip the block if the mask is entirely -inf
|
|
|
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
|
|
|
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
|
|
|
+ continue;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
|
|
+
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
|
|
|
+
|
|
|
+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
|
|
|
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
|
|
|
+ S = coopMatMulAdd(Qf16, K_T, S);
|
|
|
+
|
|
|
+ if (p.logit_softcap != 0.0f) {
|
|
|
+ [[unroll]]
|
|
|
+ for (int k = 0; k < S.length(); ++k) {
|
|
|
+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
|
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
+ }
|
|
|
+
|
|
|
// Clear padding elements to -inf, so they don't contribute to rowmax
|
|
|
if (Clamp != 0 &&
|
|
|
((j + 1) * Bc > KV ||
|