|
@@ -201,6 +201,11 @@ void main() {
|
|
|
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
|
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
|
|
uint32_t k_stride = p.nb11;
|
|
uint32_t k_stride = p.nb11;
|
|
|
uint32_t v_stride = p.nb21;
|
|
uint32_t v_stride = p.nb21;
|
|
|
|
|
+ // When using grouped query attention, all rows use the same mask (stride 0).
|
|
|
|
|
+ // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
|
|
|
|
+ // that prevents the compiler from folding the "&" through the select
|
|
|
|
|
+ // and breaking the alignment detection.
|
|
|
|
|
+ uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
|
|
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
|
|
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
|
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
|
|
{
|
|
{
|
|
@@ -209,6 +214,7 @@ void main() {
|
|
|
k_stride &= ~7;
|
|
k_stride &= ~7;
|
|
|
v_stride &= ~7;
|
|
v_stride &= ~7;
|
|
|
#endif
|
|
#endif
|
|
|
|
|
+ m_stride &= ~7;
|
|
|
}
|
|
}
|
|
|
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
|
tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
|
|
|
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
|
tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
|
|
@@ -261,10 +267,7 @@ void main() {
|
|
|
if (p.mask != 0) {
|
|
if (p.mask != 0) {
|
|
|
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
|
|
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
|
|
|
- // When using grouped query attention, all rows use the same mask.
|
|
|
|
|
- if (p.gqa_ratio > 1) {
|
|
|
|
|
- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
|
|
|
|
|
|
|
|
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
|
|
|
|
|
|