
metal : enable FA for MLA heads (#18950)

Georgi Gerganov
parent commit 271191906c

+ 2 - 6
ggml/src/ggml-metal/ggml-metal-device.m

@@ -1078,12 +1078,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 op->src[0]->ne[0] != 112 &&
                 op->src[0]->ne[0] != 128 &&
                 op->src[0]->ne[0] != 192 &&
-                op->src[0]->ne[0] != 256) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek sizes
-                // TODO: disabled for now, until optimized
+                op->src[0]->ne[0] != 256 &&
+                op->src[0]->ne[0] != 576) {
                 return false;
             }
             if (op->src[1]->type != op->src[2]->type) {

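With this change, the Metal backend's flash-attention support check becomes a single whitelist of K head sizes, to which 576 (the DeepSeek MLA head that the deleted branch used to reject) is appended. A minimal standalone sketch of the resulting predicate, keeping only the sizes visible in the hunk above (the earlier entries in the chain are elided from the diff context, so they are elided here too):

    #include <cstdint>

    // Sketch only: mirrors the visible tail of the whitelist above, with
    // dk standing in for op->src[0]->ne[0]; sizes below 112 are not shown
    // in the diff context and are therefore omitted.
    static bool fa_head_size_supported(int64_t dk) {
        return dk == 112 || dk == 128 || dk == 192 || dk == 256 ||
               dk == 576; // DeepSeek MLA head, previously rejected
    }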
+ 1 - 1
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -2520,7 +2520,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
 
         // simdgroups per threadgroup (a.k.a. warps)
         //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
-        int32_t nsg = 4;
+        int32_t nsg = ne00 >= 512 ? 8 : 4;
 
         const size_t smem = FATTN_SMEM(nsg);
 
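The host side now picks the simdgroup count from the K head size: heads 512 or wider (which covers the 576-wide MLA case) get 8 simdgroups per threadgroup instead of the previous fixed 4. A standalone sketch of the heuristic (the function name is illustrative; FATTN_SMEM and the 32-thread simdgroup width come from the surrounding ggml-metal-ops.cpp code and are not reproduced here):

    #include <cstdint>

    // Illustrative helper implementing the same heuristic as the hunk
    // above. ne00 is the K head size of the flash-attention op; 32*nsg
    // would then be the threadgroup width, per the commented-out formula
    // in the source.
    static int32_t fa_simdgroups(int64_t ne00) {
        return ne00 >= 512 ? 8 : 4; // wide (MLA) heads get twice the warps
    }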

+ 8 - 5
ggml/src/ggml-metal/ggml-metal.metal

@@ -5552,9 +5552,7 @@ void kernel_flash_attn_ext_impl(
 
                 constexpr short NC = (C/8)/NSG;
 
-                // note: do not unroll for large heads
-                #pragma unroll (DK <= 64 ? NC : 1)
-                for (short cc = 0; cc < NC; ++cc) {
+                FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                     qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
 
                     if (DK % 16 != 0) {
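This first kernel hunk drops the head-size-conditional unroll pragma in favor of the file's FOR_UNROLL helper. FOR_UNROLL is not shown in this diff; a plausible definition, assuming it simply requests full unrolling via _Pragma, would be:

    // Assumption: FOR_UNROLL fully unrolls the loop that follows. The
    // real macro lives elsewhere in ggml-metal.metal and may differ.
    #define FOR_UNROLL(x) _Pragma("clang loop unroll(full)") for (x)

The cap on unrolling does not disappear, though; the next hunk moves it onto the inner DK loop, where the iteration count actually grows with the head size.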
@@ -5575,7 +5573,9 @@ void kernel_flash_attn_ext_impl(
                         k8x8_t mk[2];
                         q8x8_t mq[2];
 
-                        FOR_UNROLL (short i = 0; i < DK8/2; ++i) {
+                        // note: too much unroll can tank the performance for large heads
+                        #pragma unroll (MIN(DK8/2, 4*NSG))
+                        for (short i = 0; i < DK8/2; ++i) {
                             simdgroup_barrier(mem_flags::mem_none);
 
                             simdgroup_load(mq[0], pq + 0*8 + 16*i, DK);
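Here the blanket FOR_UNROLL is replaced by an unroll factor capped at 4*NSG: for a 576-wide head, DK8/2 would otherwise be 36 iterations (assuming DK8 = DK/8), which per the new comment is enough unrolling to hurt performance. A reduced C++ analogue of capping an unroll count with a compile-time MIN (the constants are illustrative; clang accepts a constant expression in the pragma, which is what the Metal source relies on):

    // Illustrative values: DK = 576 gives DK8 = 72 under the DK8 = DK/8
    // assumption, so the loop below has 36 iterations.
    #define MIN(a, b) ((a) < (b) ? (a) : (b))
    enum { DK8 = 72, NSG = 8 };

    void accumulate(float *acc, const float *x) {
        // unroll at most MIN(DK8/2, 4*NSG) = MIN(36, 32) = 32 iterations
        #pragma unroll (MIN(DK8/2, 4*NSG))
        for (short i = 0; i < DK8/2; ++i) {
            acc[i] += x[i];
        }
    }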
@@ -5749,7 +5749,9 @@ void kernel_flash_attn_ext_impl(
                                 pv  += 8*NS20;
                             }
                         } else {
-                            FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) {
+                            constexpr short NC = (C/8)/2;
+
+                            FOR_UNROLL (short cc = 0; cc < NC; ++cc) {
                                 s8x8_t vs[2];
 
                                 simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false);
@@ -5952,6 +5954,7 @@ kernel void kernel_flash_attn_ext(
       //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break;
       //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break;
         case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break;
+        case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break;
     }
 #undef FWD_TMPL
 #undef FWD_ARGS
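The last hunk gives the host-selected nsg = 8 a matching template instantiation; since the switch has no default, an unmatched value would simply launch nothing, so this case must exist for the new heuristic to take effect. A reduced C++ sketch of the same run-time-to-compile-time dispatch pattern (names are illustrative):

    #include <cstdio>

    // Sketch: a run-time simdgroup count selects a variant compiled with
    // that count as a template parameter, as kernel_flash_attn_ext does.
    template <int NSG>
    static void flash_attn_impl() {
        std::printf("running with %d simdgroups\n", NSG);
    }

    static void flash_attn_dispatch(int nsg) {
        switch (nsg) {
            case 4: flash_attn_impl<4>(); break;
            case 8: flash_attn_impl<8>(); break; // new: wide MLA heads
            // no default: unmatched counts run nothing, so the host
            // heuristic and this switch must agree on supported values
        }
    }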