Browse Source

CUDA: faster FA for GQA > 1 but not power of 2 (#19092)

Johannes Gäßler 5 days ago
parent
commit
0c21677e43

+ 12 - 10
ggml/src/ggml-cuda/fattn-common.cuh

@@ -643,9 +643,10 @@ static __global__ void flash_attn_stream_k_fixup(
 
     const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
     const int iter_j = (ne01 + (ncols1    - 1)) / ncols1;
+    const int iter_z = (ne02 + (ncols2    - 1)) / ncols2;
 
-    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    const int kbc0      = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -654,15 +655,15 @@ static __global__ void flash_attn_stream_k_fixup(
         return;
     }
 
-    const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
-    const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
-    const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
+    const int sequence = kbc0 / (iter_k*iter_j*iter_z);
+    const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
+    const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
-    if (jt*ncols1 + j >= ne01) {
+    if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
         return;
     }
 
-    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
 
     // Load the partial result that needs a fixup:
     float dst_val = 0.0f;
@@ -681,7 +682,7 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+        const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
@@ -883,7 +884,8 @@ void launch_fattn(
     }
 
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
-    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
+    const int ntiles_z =  ((Q->ne[2] + ncols2 - 1) / ncols2);
+    const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
 
     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
     // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
@@ -958,7 +960,7 @@ void launch_fattn(
 
         blocks_num.x = ntiles_x;
         blocks_num.y = parallel_blocks;
-        blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
+        blocks_num.z = ntiles_z*Q->ne[3];
 
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));

+ 17 - 13
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -940,6 +940,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const int stride_V,
         const int stride_mask,
         const int jt,
+        const int zt,
         const int kb0_start,
         const int kb0_stop) {
 #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
@@ -1022,7 +1023,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             const int j = jc / ncols2;
             const int c = jc % ncols2;
 
-            if (jt*ncols1 + j < int(ne01.z)) {
+            if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
 #pragma unroll
                 for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
                     const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -1408,7 +1409,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                     const int j_dst = jc_dst / ncols2;
                     const int c_dst = jc_dst % ncols2;
 
-                    if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
+                    if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
                         continue;
                     }
 
@@ -1522,10 +1523,11 @@ static __global__ void flash_attn_ext_f16(
 
     const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
     const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
+    const int iter_z = (ne02   + (ncols2    - 1)) / ncols2;
 
     // kbc == k block continuous, current index in continuous ijk space.
-    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
-    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
+    int       kbc      = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+    const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
 
     // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
     // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1536,9 +1538,9 @@ static __global__ void flash_attn_ext_f16(
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
 
     while (kbc < kbc_stop && kb0_stop == iter_k) {
-        const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-        const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-        const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+        const int sequence = kbc / (iter_k*iter_j*iter_z);
+        const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
+        const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
         const int head0 = zt * ncols2;
 
@@ -1561,12 +1563,12 @@ static __global__ void flash_attn_ext_f16(
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
             flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+                 ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
         }
 
         kbc += iter_k;
@@ -1580,9 +1582,9 @@ static __global__ void flash_attn_ext_f16(
         return;
     }
 
-    const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
-    const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
-    const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+    const int sequence = kbc / (iter_k*iter_j*iter_z);
+    const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
+    const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
 
     const int head0 = zt * ncols2;
 
@@ -1605,7 +1607,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr bool needs_fixup = false;
     flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
-         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
+         ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
 #else
     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
         max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1739,3 +1741,5 @@ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  1, 32);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  2, 32);

+ 57 - 10
ggml/src/ggml-cuda/fattn.cu

@@ -18,9 +18,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         }
     }
 
-    if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
-        return;
+    if constexpr (ncols2 <= 16) {
+        if ((turing_mma_available(cc) || amd_wmma_available(cc)) && Q->ne[1] <= 16/ncols2) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
+            return;
+        }
     }
 
     if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || amd_wmma_available(cc) || Q->ne[1] <= 32/ncols2) {
@@ -33,6 +35,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
 
 template <int DKQ, int DV>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -60,17 +63,38 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (use_gqa_opt && gqa_ratio % 8 == 0) {
+    // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute:
+    if (cc == GGML_CUDA_CC_VOLTA) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
+            return;
+        }
+
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+        return;
+    }
+
+    if (use_gqa_opt && gqa_ratio > 4) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 4 == 0) {
+    if (use_gqa_opt && gqa_ratio > 2) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio % 2 == 0) {
+    if (use_gqa_opt && gqa_ratio > 1) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
         return;
     }
@@ -79,6 +103,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
 }
 
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -121,8 +146,30 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 
             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
             const int gqa_ratio = Q->ne[2] / K->ne[2];
-            GGML_ASSERT(gqa_ratio % 4 == 0);
-            if (gqa_ratio % 16 == 0) {
+            if (gqa_ratio == 20) { // GLM 4.7 Flash
+                if (cc >= GGML_CUDA_CC_BLACKWELL) {
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 4) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    if (Q->ne[1] <= 4) {
+                        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst);
+                        break;
+                    }
+                    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+                    break;
+                }
+                // Volta:
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+            } else if (gqa_ratio % 16 == 0) {
                 ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
@@ -234,7 +281,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // The effective batch size for the kernel can be increased by gqa_ratio.
     // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
-    bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
     for (const ggml_tensor * t : {Q, K, V, mask}) {
         if (t == nullptr || ggml_is_quantized(t->type)) {
             continue;
@@ -268,7 +315,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
+            if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             if (!V_is_K_view) {

+ 5 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);

+ 5 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);

+ 3 - 3
ggml/src/ggml-cuda/template-instances/generate_cu_files.py

@@ -71,7 +71,7 @@ for type_k in TYPES_KV:
             f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v))
 
 for ncols in [8, 16, 32, 64]:
-    for ncols2 in [1, 2, 4, 8, 16]:
+    for ncols2 in [1, 2, 4, 8, 16, 32]:
         if ncols2 > ncols:
             continue
         ncols1 = ncols // ncols2
@@ -83,9 +83,9 @@ for ncols in [8, 16, 32, 64]:
                     continue
                 if head_size_kq == 72:
                     continue
-                if head_size_kq != 576 and ncols2 == 16:
+                if head_size_kq != 576 and ncols2 in (16, 32):
                     continue
-                if head_size_kq == 576 and ncols2 not in (4, 16):
+                if head_size_kq == 576 and ncols2 not in (4, 16, 32):
                     continue
                 head_size_v = head_size_kq if head_size_kq != 576 else 512
                 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))