浏览代码

vulkan: use aligned loads for flash attention mask (#12853)

Rewrite the stride logic for the mask tensor in the FA shader to force the
stride to be aligned, to allow using more efficient loads.
Jeff Bolz 10 月之前
父节点
当前提交
a4837577aa
共有 1 个文件被更改,包括 7 次插入4 次删除
  1. 7 4
      ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

+ 7 - 4
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

@@ -201,6 +201,11 @@ void main() {
     uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
     uint32_t k_stride = p.nb11;
     uint32_t v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
@@ -209,6 +214,7 @@ void main() {
         k_stride &= ~7;
         v_stride &= ~7;
 #endif
+        m_stride &= ~7;
     }
     tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
@@ -261,10 +267,7 @@ void main() {
         if (p.mask != 0) {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
             tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
-            // When using grouped query attention, all rows use the same mask.
-            if (p.gqa_ratio > 1) {
-                tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, 0, 1);
-            }
+            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
             coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;