|
@@ -154,15 +154,31 @@ void main() {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
|
|
|
- tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
|
|
|
- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
|
|
|
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
|
|
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
|
|
|
|
|
|
|
|
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
|
|
|
|
+ if (nem1_bounds_check) {
|
|
|
|
|
+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
|
|
|
|
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
|
|
|
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
|
|
|
- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
|
|
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
|
|
|
|
|
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
|
|
|
|
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
|
|
|
+
|
|
|
|
|
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
|
|
|
+ // Don't clamp against nem1 when GQA is enabled
|
|
|
|
|
+ uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
|
|
|
|
|
+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
|
|
|
|
|
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
+
|
|
|
|
|
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
|
|
+
|
|
|
|
|
+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
|
|
|
|
|
+
|
|
|
|
|
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Clear padding elements to -inf, so they don't contribute to rowmax
|
|
// Clear padding elements to -inf, so they don't contribute to rowmax
|