
vulkan: skip all-negative-inf blocks in FA (#17186)

Jeff Bolz, 2 months ago
Commit 234ae7d7bd
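
At a high level, each flash-attention shader variant now loads the Br x Bc mask tile for a KV block before doing any math, and skips the Q*K^T matmul, softmax update and V accumulation for that block when every mask value is -inf (for example, blocks fully cut off by a causal mask). A minimal CPU-side sketch of the predicate, assuming a row-major tile and using -FLT_MAX/2 as the "-inf" threshold to match the shaders' NEG_FLT_MAX_OVER_2 sentinel (names here are illustrative, not from the repo):

```cpp
#include <algorithm>
#include <cfloat>
#include <vector>

// Hedged CPU-side sketch of the skip test this commit adds to the FA shaders:
// a Br x Bc mask tile whose maximum is still <= -FLT_MAX/2 is entirely -inf,
// contributes nothing after softmax, and the whole KV block can be skipped.
static bool block_is_fully_masked(const std::vector<float> &mask_tile) {
    const float neg_flt_max_over_2 = -FLT_MAX / 2.0f;
    float max_mask = neg_flt_max_over_2;
    for (float m : mask_tile) {
        max_mask = std::max(max_mask, m);
    }
    return max_mask <= neg_flt_max_over_2;
}
```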

+ 6 - 2
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -521,6 +521,7 @@ struct vk_device_struct {
     bool subgroup_shuffle;
     bool subgroup_ballot;
     bool subgroup_clustered;
+    bool subgroup_vote;
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
@@ -4188,6 +4189,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot);
 
+        device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote);
+
         const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -13572,8 +13576,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 default:
                     return false;
                 }
-                if (!coopmat2 && !device->subgroup_shuffle) {
-                    // scalar FA uses subgroupShuffle
+                if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
+                    // scalar/coopmat1 FA uses subgroupShuffle/subgroupAll
                     return false;
                 }
                 return true;
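
The host-side change mirrors the existing subgroup feature checks: the scalar and coopmat1 FA shaders now also need subgroup vote support (subgroupAll) in compute stages, so a device->subgroup_vote flag is queried and folded into the supports_op test alongside subgroup_shuffle. For reference, a hedged sketch of the equivalent query through the plain C Vulkan 1.1 API (the repo itself uses the vulkan.hpp wrappers shown above; has_compute_subgroup_vote is an illustrative name):

```cpp
#include <vulkan/vulkan.h>

// Subgroup vote support is advertised via VkPhysicalDeviceSubgroupProperties
// (core in Vulkan 1.1), chained into VkPhysicalDeviceProperties2.
static bool has_compute_subgroup_vote(VkPhysicalDevice physical_device) {
    VkPhysicalDeviceSubgroupProperties subgroup_props = {};
    subgroup_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES;

    VkPhysicalDeviceProperties2 props2 = {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    props2.pNext = &subgroup_props;

    vkGetPhysicalDeviceProperties2(physical_device, &props2);

    // The shaders use subgroupAll() in compute, so both bits must be present.
    return (subgroup_props.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) &&
           (subgroup_props.supportedOperations & VK_SUBGROUP_FEATURE_VOTE_BIT);
}
```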

+ 33 - 15
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp

@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_shuffle : enable
+#extension GL_KHR_shader_subgroup_vote : enable
 
 #include "types.glsl"
 #include "flash_attn_base.glsl"
@@ -108,6 +109,38 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        masksh[c][r] = m;
+                        max_mask = max(max_mask, m);
+                    } else {
+                        masksh[c][r] = float(0);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         float Sf[Br][cols_per_thread];
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
@@ -153,21 +186,6 @@ void main() {
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                    } else {
-                        masksh[c][r] = float(0);
-                    }
-                }
-            }
-            barrier();
-
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
                     float mvf = masksh[c * cols_per_iter + col_tid][r];
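
In the scalar shader the mask load is hoisted to the top of the KV loop and doubles as the skip test: each invocation keeps a running max of the mask values it loads into masksh, subgroupAll() collapses that to one verdict per subgroup, lane 0 of each subgroup publishes a sentinel (or 0.0f) into tmpsh, and every invocation then takes the max over those entries so the whole workgroup agrees on whether to `continue`. A hedged host-side model of that two-level reduction (names are illustrative; the real reduction runs on the GPU):

```cpp
#include <algorithm>
#include <cfloat>
#include <vector>

// One inner vector per subgroup, holding each lane's running max of the mask
// values it loaded for this block.
static bool workgroup_block_is_masked(const std::vector<std::vector<float>> &lane_max_per_subgroup) {
    const float neg_flt_max_over_2 = -FLT_MAX / 2.0f;
    float max_mask = neg_flt_max_over_2;
    for (const std::vector<float> &lanes : lane_max_per_subgroup) {
        // subgroupAll(): true only if every lane's running max is still the sentinel
        const bool all_less = std::all_of(lanes.begin(), lanes.end(),
            [&](float m) { return m <= neg_flt_max_over_2; });
        // tmpsh[gl_SubgroupID] receives either the sentinel or 0.0f
        max_mask = std::max(max_mask, all_less ? neg_flt_max_over_2 : 0.0f);
    }
    return max_mask <= neg_flt_max_over_2;   // true => `continue` to the next KV block
}
```

Because the mask values now stay resident in masksh, the later per-column mask add in the same iteration reads them back from shared memory instead of reloading data_m (the hunk removed above).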

+ 34 - 1
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp

@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
 
@@ -148,6 +149,37 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
+        float mask_cache[Bc * Br / WorkGroupSize];
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+            float max_mask = NEG_FLT_MAX_OVER_2;
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                        mask_cache[idx / WorkGroupSize] = m;
+                        max_mask = max(max_mask, m);
+                    }
+                }
+            }
+            // skip the block if the mask is entirely -inf
+            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+            barrier();
+            if (gl_SubgroupInvocationID == 0) {
+                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+            }
+            barrier();
+            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                max_mask = max(max_mask, tmpsh[s]);
+            }
+            if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                continue;
+            }
+        }
+
         [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
             uint32_t d = (idx + tid) % (HSK / 4);
             uint32_t c = (idx + tid) / (HSK / 4);
@@ -208,7 +240,8 @@ void main() {
                 uint32_t r = (idx + tid) / Bc;
                 if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
                     if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+                        float f = mask_cache[idx / WorkGroupSize];
+                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
                     }
                 }
             }
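
The coopmat1 shader applies the same early-out but caches the loaded mask values in a per-thread register array (mask_cache) rather than shared memory, so the later slope[r] * mask add into sfsh replays them instead of re-reading data_m. Since the load loop strides by the workgroup size, idx / WorkGroupSize is a dense index into that small private array; a hedged sketch of the pattern, with illustrative constants and function name:

```cpp
// Each thread handles at most Bc*Br/WorkGroupSize elements of the tile, so
// idx / WorkGroupSize enumerates its private cache slots; the second pass over
// the tile reuses the same slot indexing.
constexpr unsigned Bc = 16, Br = 16, WorkGroupSize = 128;

void cache_mask_tile(const float *mask_tile, unsigned tid,
                     float (&mask_cache)[Bc * Br / WorkGroupSize]) {
    for (unsigned idx = 0; idx < Bc * Br; idx += WorkGroupSize) {
        if (idx + tid < Bc * Br) {
            // first pass: load and remember the value in this thread's slot
            mask_cache[idx / WorkGroupSize] = mask_tile[idx + tid];
        }
    }
}
```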

+ 37 - 19
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
 }
 
+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
+    return max(x, y);
+}
+
 ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return x;
 }
@@ -142,21 +146,7 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-
-        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
-
-        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
-        S = coopMatMulAdd(Qf16, K_T, S);
-
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]]
-            for (int k = 0; k < S.length(); ++k) {
-                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
-            }
-        }
-
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
@@ -164,12 +154,17 @@ void main() {
                 tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             } else {
                 tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                 // Don't clamp against nem1 when GQA is enabled
@@ -177,14 +172,37 @@ void main() {
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
+
+        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
+        S = coopMatMulAdd(Qf16, K_T, S);
+
+        if (p.logit_softcap != 0.0f) {
+            [[unroll]]
+            for (int k = 0; k < S.length(); ++k) {
+                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+            }
+        }
+
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        }
+
         // Clear padding elements to -inf, so they don't contribute to rowmax
         if (Clamp != 0 &&
             ((j + 1) * Bc > KV ||