
CUDA: optimize FA for GQA + large batches (#12014)

Johannes Gäßler 10 months ago
parent
commit
5fa07c2f93
32 changed files with 862 additions and 356 deletions
  1. ggml/src/ggml-cuda/cp-async.cuh (+1, -1)
  2. ggml/src/ggml-cuda/fattn-common.cuh (+53, -69)
  3. ggml/src/ggml-cuda/fattn-mma-f16.cuh (+475, -198)
  4. ggml/src/ggml-cuda/fattn-tile-f16.cu (+2, -2)
  5. ggml/src/ggml-cuda/fattn-tile-f32.cu (+2, -2)
  6. ggml/src/ggml-cuda/fattn-vec-f16.cuh (+1, -1)
  7. ggml/src/ggml-cuda/fattn-vec-f32.cuh (+1, -1)
  8. ggml/src/ggml-cuda/fattn-wmma-f16.cu (+3, -3)
  9. ggml/src/ggml-cuda/fattn.cu (+56, -17)
  10. ggml/src/ggml-cuda/mma.cuh (+75, -0)
  11. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu (+0, -10)
  12. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu (+0, -10)
  13. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu (+0, -10)
  14. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu (+0, -10)
  15. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu (+10, -0)
  16. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu (+10, -0)
  17. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu (+10, -0)
  18. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu (+10, -0)
  19. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu (+10, -0)
  20. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu (+10, -0)
  21. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu (+10, -0)
  22. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu (+10, -0)
  23. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu (+10, -0)
  24. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu (+10, -0)
  25. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu (+10, -0)
  26. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu (+10, -0)
  27. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu (+10, -0)
  28. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu (+10, -0)
  29. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu (+10, -0)
  30. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu (+10, -0)
  31. ggml/src/ggml-cuda/template-instances/generate_cu_files.py (+13, -7)
  32. tests/test-backend-ops.cpp (+20, -15)

+ 1 - 1
ggml/src/ggml-cuda/cp-async.cuh

@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
     } else
 #endif // CUDART_VERSION >= 11040
     {
-        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
             : : "r"(dst), "l"(src));
     }
 #else
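
The changed line removes a bare `.L2` qualifier from the fallback copy: without a prefetch size such as `::128B` it is not a valid cp.async suffix, so the plain form is emitted instead. For orientation, a minimal standalone sketch of the same asynchronous-copy pattern; the function names here are illustrative and not taken from the header, which additionally guards these instructions for Ampere or newer:

    // Copy 16 bytes from global to shared memory without stalling the thread.
    // dst_smem must already be a 32-bit shared-memory address.
    static __device__ __forceinline__ void copy_16B_async(const unsigned int dst_smem, const void * src_gmem) {
        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
            : : "r"(dst_smem), "l"(src_gmem));
    }

    // Block until all cp.async operations issued by this thread have completed.
    static __device__ __forceinline__ void wait_async_copies() {
        asm volatile("cp.async.wait_all;");
    }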

+ 53 - 69
ggml/src/ggml-cuda/fattn-common.cuh

@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j]  = tmp.y;
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x      - max_val_new;
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x   - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
     const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves  = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
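
The fixup kernel no longer loops over all ncols rows per thread: launch_fattn now spawns a (blocks, ncols1, ncols2) grid and each block handles a single output row identified by jc = j*ncols2 + c. The merge rule itself is unchanged. As a standalone, hedged sketch of that rule (illustrative names; the kernel additionally flushes scales below SOFTMAX_FTZ_THRESHOLD to zero):

    // One partially reduced FlashAttention output element together with the
    // softmax statistics needed to merge it with another partial result.
    struct partial_t {
        float val;     // unnormalized output accumulator
        float max_val; // running maximum of the attention logits
        float rowsum;  // running softmax denominator
    };

    static __device__ __forceinline__ partial_t combine(const partial_t a, const partial_t b) {
        // Rescale both accumulators to a common maximum before adding them,
        // mirroring the online-softmax update in the kernel above.
        const float m       = fmaxf(a.max_val, b.max_val);
        const float scale_a = expf(a.max_val - m);
        const float scale_b = expf(b.max_val - m);

        partial_t out;
        out.val     = scale_a*a.val    + scale_b*b.val;
        out.rowsum  = scale_a*a.rowsum + scale_b*b.rowsum;
        out.max_val = m;
        return out; // the final result is out.val / out.rowsum, as in the write-back
    }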

File diff suppressed because it is too large
+ 475 - 198
ggml/src/ggml-cuda/fattn-mma-f16.cuh
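
The bulk of the commit (+475, -198) lives in this suppressed header: the MMA FlashAttention kernel is re-templated from a single cols_per_block parameter to ncols1 (Q positions per tile) and ncols2 (Q heads mapped onto the same K/V head). Packing a whole GQA group into one tile lets each K/V tile loaded into shared memory be reused across ncols2 heads and keeps the tensor-core tiles filled even at small batch sizes. A minimal sketch of the index decomposition only, with illustrative names (the real kernel also handles shared-memory tiling, cp.async staging and the MMA calls from mma.cuh):

    // Map a fused column index jc in [0, ncols1*ncols2) back to a Q position j
    // and a head c within the GQA group, matching jc = j*ncols2 + c above.
    template <int ncols1, int ncols2>
    static __device__ __forceinline__ void decompose_jc(const int jc, int & j, int & c) {
        j = jc / ncols2; // Q position within the tile (0 .. ncols1-1)
        c = jc % ncols2; // Q head sharing one K/V head (0 .. ncols2-1)
    }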


+ 2 - 2
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

+ 2 - 2
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

+ 1 - 1
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>

+ 1 - 1
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>

+ 3 - 3
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

+ 56 - 17
ggml/src/ggml-cuda/fattn.cu

@@ -8,28 +8,50 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-template <int cols_per_block>
+template <int D, int ncols2>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    if (Q->ne[1] <= 8/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 16/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    if (Q->ne[1] <= 32/ncols2) {
+        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
+        return;
+    }
+
+    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
+}
+
+template <int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
     switch (Q->ne[0]) {
         case 64:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
             break;
         case 80:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
             break;
         case 96:
-            ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
             break;
         case 112:
-            ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
             break;
         case 128:
-            ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
             break;
         case 256:
-            ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
 }
 
 static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    const float use_gqa_opt = mask && max_bias == 0.0f;
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
 
-    if (Q->ne[1] <= 8) {
+    if (use_gqa_opt && gqa_ratio % 8 == 0) {
         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 16) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 4) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
         return;
     }
 
-    if (Q->ne[1] <= 32) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio == 2) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
 }
 
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV = dst;
-    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
+        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
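
Dispatch summary: the GQA path is only taken when a mask is present and max_bias == 0; ncols2 is then chosen from the GQA ratio (8, 4, 2, otherwise 1) and ncols1 from the number of Q columns, so the fused tile width ncols1*ncols2 stays in {8, 16, 32, 64}. A hedged, standalone restatement of that arithmetic (host-side sketch, not repository code):

    // Mirrors ggml_cuda_flash_attn_ext_mma_f16 above.
    static int pick_ncols2(const bool has_mask, const float max_bias, const int gqa_ratio) {
        const bool use_gqa_opt = has_mask && max_bias == 0.0f;
        if (use_gqa_opt && gqa_ratio % 8 == 0) return 8;
        if (use_gqa_opt && gqa_ratio == 4)     return 4;
        if (use_gqa_opt && gqa_ratio == 2)     return 2;
        return 1;
    }

    // Mirrors ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1 above.
    static int pick_ncols1(const int n_q_cols, const int ncols2) {
        if (n_q_cols <=  8/ncols2) return  8/ncols2;
        if (n_q_cols <= 16/ncols2) return 16/ncols2;
        if (n_q_cols <= 32/ncols2) return 32/ncols2;
        return 64/ncols2;
    }

    // Example: 32 Q heads over 8 K/V heads gives gqa_ratio = 4, so ncols2 = 4;
    // single-token decoding (n_q_cols = 1) then selects ncols1 = 8/4 = 2.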

+ 75 - 0
ggml/src/ggml-cuda/mma.cuh

@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
                 return threadIdx.x / 4;
             } else if constexpr (I == 16 && J == 8) {
                 return (l / 2) * 8 + threadIdx.x / 4;
+            } else if constexpr (I == 16 && J == 16) {
+                return ((l / 2) % 2) * 8 + threadIdx.x / 4;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
                 return 4 * l + threadIdx.x % 4;
             } else if constexpr (I == 16 && J == 8) {
                 return 2 * (threadIdx.x % 4) + l % 2;
+            } else if constexpr (I == 16 && J == 16) {
+                return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
             } else {
                 static_assert(I == -1 && J == -1, "template specialization not implemented");
             }
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
+            : "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef NEW_MMA_AVAILABLE
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
 #endif // NEW_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma(
+            tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
+#ifdef NEW_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        int       * Dxi = (int       *) D.x;
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
+#else
+        // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
+        asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
+            : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
+            : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
+#else
+        GGML_UNUSED(D);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+    }
 }
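
The new overloads add f16- and f32-accumulating m16n8k16 matrix multiplications for tile<16, 8, half2> operands (logically a 16x16 block of halves), with a Turing fallback built from four m8n8k8 instructions. As a standalone, hedged sketch of a single such PTX instruction (fragment layout per the PTX ISA; not repository code, and only valid on Ampere or newer):

    // Per warp: D(16x8, f32) += A(16x16, f16) * B(16x8, f16).
    // a[] and b[] hold packed half2 pairs in the standard mma fragment layout.
    static __device__ __forceinline__ void mma_m16n8k16_f32(float d[4], const unsigned int a[4], const unsigned int b[2]) {
    #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
        asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
            : "+f"(d[0]), "+f"(d[1]), "+f"(d[2]), "+f"(d[3])
            : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]));
    #endif // __CUDA_ARCH__ >= 800
    }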

+ 0 - 10
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb16.cu

@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 16);
-DECL_FATTN_MMA_F16_CASE(80, 16);
-DECL_FATTN_MMA_F16_CASE(96, 16);
-DECL_FATTN_MMA_F16_CASE(112, 16);
-DECL_FATTN_MMA_F16_CASE(128, 16);
-DECL_FATTN_MMA_F16_CASE(256, 16);

+ 0 - 10
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb32.cu

@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 32);
-DECL_FATTN_MMA_F16_CASE(80, 32);
-DECL_FATTN_MMA_F16_CASE(96, 32);
-DECL_FATTN_MMA_F16_CASE(112, 32);
-DECL_FATTN_MMA_F16_CASE(128, 32);
-DECL_FATTN_MMA_F16_CASE(256, 32);

+ 0 - 10
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb64.cu

@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 64);
-DECL_FATTN_MMA_F16_CASE(80, 64);
-DECL_FATTN_MMA_F16_CASE(96, 64);
-DECL_FATTN_MMA_F16_CASE(112, 64);
-DECL_FATTN_MMA_F16_CASE(128, 64);
-DECL_FATTN_MMA_F16_CASE(256, 64);

+ 0 - 10
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-cpb8.cu

@@ -1,10 +0,0 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
-#include "../fattn-mma-f16.cuh"
-
-DECL_FATTN_MMA_F16_CASE(64, 8);
-DECL_FATTN_MMA_F16_CASE(80, 8);
-DECL_FATTN_MMA_F16_CASE(96, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 1, 8);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 16, 1);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 16, 2);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 16, 4);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 2, 4);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 2, 8);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 32, 1);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 32, 2);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 4, 2);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 4, 4);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 4, 8);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 64, 1);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 8, 1);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 8, 2);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 8, 4);

+ 10 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu

@@ -0,0 +1,10 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 8, 8);

+ 13 - 7
ggml/src/ggml-cuda/template-instances/generate_cu_files.py

@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
 
 """
 
-SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
 
-for cols_per_block in [8, 16, 32, 64]:
-    with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
-        f.write(SOURCE_FATTN_MMA_START)
-
-        for head_size in [64, 80, 96, 112, 128, 256]:
-            f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
+for ncols in [8, 16, 32, 64, 128]:
+    for ncols2 in [1, 2, 4, 8]:
+        ncols1 = ncols // ncols2
+        if ncols == 128:
+            continue  # Too much register pressure.
+        with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
+            f.write(SOURCE_FATTN_MMA_START)
+
+            for head_size in [64, 80, 96, 112, 128, 256]:
+                if ncols == 128 and head_size == 256:
+                    continue  # Needs too much shared memory.
+                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
 
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

+ 20 - 15
tests/test-backend-ops.cpp

@@ -3119,6 +3119,7 @@ struct test_leaky_relu : public test_case {
 struct test_flash_attn_ext : public test_case {
     const int64_t hs; // head size
     const int64_t nh; // num heads
+    const int64_t nr; // repeat in Q, tests for grouped-query attention
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
@@ -3131,7 +3132,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR9(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3142,13 +3143,13 @@ struct test_flash_attn_ext : public test_case {
         GGML_UNUSED(t);
         // Just counting matmul costs:
         // Q*K^T is nb x hs x kv, P*V is nb x kv x hs, per head
-        return 2 * 2 * nh * nb * hs * kv;
+        return 2 * 2 * nh*nr * nb * hs * kv;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8,
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
                         bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
                         std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3166,13 +3167,13 @@ struct test_flash_attn_ext : public test_case {
             return t;
         };
 
-        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh, 1);
+        ggml_tensor * q = create_permuted(GGML_TYPE_F32, hs_padded, nb, nh*nr, 1);
         ggml_set_name(q, "q");
 
-        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * k = create_permuted(type_KV,       hs_padded, kv, nh,    1);
         ggml_set_name(k, "k");
 
-        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh, 1);
+        ggml_tensor * v = create_permuted(type_KV,       hs_padded, kv, nh,    1);
         ggml_set_name(v, "v");
 
         ggml_tensor * m = nullptr;
@@ -4278,14 +4279,18 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 if (!mask && max_bias > 0.0f) continue;
                 for (float logit_softcap : {0.0f, 10.0f}) {
                     if (hs != 128 && logit_softcap != 0.0f) continue;
-                    for (int nh : { 32, }) {
-                        for (int kv : { 512, 1024, }) {
-                            for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                    // run fewer test cases permuted
-                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                    for (int nh : { 4, }) {
+                        for (int nr : { 1, 4, 16 }) {
+                            if (nr == 16 && hs != 128) continue;
+                            for (int kv : { 512, 1024, }) {
+                                if (nr != 1 && kv != 512) continue;
+                                for (int nb : { 1, 3, 32, 35, }) {
+                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                        // run fewer test cases permuted
+                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                        }
                                     }
                                 }
                             }

Some files were not shown because too many files changed in this diff