CUDA: FA support for Deepseek (Ampere or newer) (#13306)

* CUDA: FA support for Deepseek (Ampere or newer)

* do loop unrolling via C++ template
Author: Johannes Gäßler, 8 months ago
Commit: 0cf6725e9f
33 changed files with 719 additions and 445 deletions
  1. ggml/src/ggml-cuda/CMakeLists.txt (+1 -1)
  2. ggml/src/ggml-cuda/common.cuh (+19 -0)
  3. ggml/src/ggml-cuda/cp-async.cuh (+11 -0)
  4. ggml/src/ggml-cuda/fattn-common.cuh (+13 -13)
  5. ggml/src/ggml-cuda/fattn-mma-f16.cuh (+457 -268)
  6. ggml/src/ggml-cuda/fattn-tile-f16.cu (+2 -2)
  7. ggml/src/ggml-cuda/fattn-tile-f32.cu (+2 -2)
  8. ggml/src/ggml-cuda/fattn-vec-f16.cuh (+1 -1)
  9. ggml/src/ggml-cuda/fattn-vec-f32.cuh (+1 -1)
  10. ggml/src/ggml-cuda/fattn-wmma-f16.cu (+1 -1)
  11. ggml/src/ggml-cuda/fattn.cu (+71 -45)
  12. ggml/src/ggml-cuda/ggml-cuda.cu (+6 -6)
  13. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu (+5 -0)
  14. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu (+6 -6)
  15. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu (+6 -6)
  16. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu (+6 -6)
  17. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu (+6 -6)
  18. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu (+5 -0)
  19. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu (+6 -6)
  20. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu (+6 -6)
  21. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu (+6 -6)
  22. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu (+6 -6)
  23. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu (+5 -0)
  24. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu (+6 -6)
  25. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu (+6 -6)
  26. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu (+6 -6)
  27. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu (+6 -6)
  28. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu (+6 -6)
  29. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu (+6 -6)
  30. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu (+6 -6)
  31. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu (+6 -6)
  32. ggml/src/ggml-cuda/template-instances/generate_cu_files.py (+12 -9)
  33. src/llama-graph.cpp (+11 -0)

+ 1 - 1
ggml/src/ggml-cuda/CMakeLists.txt

@@ -118,7 +118,7 @@ if (CUDAToolkit_FOUND)
 
     set(CUDA_CXX_FLAGS "")
 
-    set(CUDA_FLAGS -use_fast_math)
+    set(CUDA_FLAGS -use_fast_math -extended-lambda)
 
     if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
         # Options are:

+ 19 - 0
ggml/src/ggml-cuda/common.cuh

@@ -296,6 +296,25 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+// The compiler is not always able to unroll loops if they contain continue statements.
+// In such cases loop unrolling can still be achieved via recursion:
+template <int n>
+struct ggml_cuda_unroll {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_cuda_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_cuda_unroll<1> {
+    template <typename Func, typename... Args>
+    __device__ void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
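
For context, a minimal usage sketch of the new ggml_cuda_unroll helper (the device function and its body below are illustrative assumptions, not code from this commit): the body of the original #pragma unroll loop becomes a lambda, the continue becomes an early return, and with this implementation the index argument runs n-1, n-2, ..., 0. The -extended-lambda flag added in CMakeLists.txt above is presumably related to building these lambda-based call sites.

    // Hypothetical example, assuming common.cuh is included:
    template <int n>
    static __device__ __forceinline__ void scale_nonzero(float * vals, const float scale) {
        ggml_cuda_unroll<n>{}([&](const int i) {
            if (vals[i] == 0.0f) {
                return; // plays the role of `continue` in the original loop
            }
            vals[i] *= scale;
        });
    }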

+ 11 - 0
ggml/src/ggml-cuda/cp-async.cuh

@@ -2,6 +2,17 @@
 
 #include "common.cuh"
 
+
+static __device__ __forceinline__ unsigned int ggml_cuda_cvta_generic_to_shared(void * generic_ptr) {
+#ifdef CP_ASYNC_AVAILABLE
+    return __cvta_generic_to_shared(generic_ptr);
+#else
+    GGML_UNUSED(generic_ptr);
+    NO_DEVICE_CODE;
+    return 0;
+#endif // CP_ASYNC_AVAILABLE
+}
+
 // Copies data from global to shared memory, cg == cache global.
 // Both the src and dst pointers must be aligned to 16 bytes.
 // Shared memory uses 32 bit addressing, the pointer is passed as unsigned int.
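
A brief usage sketch of the new wrapper (the function and buffer names below are assumptions for illustration): the generic shared-memory pointer is converted once per tile into the 32-bit address space that cp.async instructions operate on, and individual 16-byte chunks are then addressed via byte offsets; on devices without cp.async support the wrapper falls through to NO_DEVICE_CODE instead.

    // Hypothetical illustration, assuming this header's includes:
    static __device__ __forceinline__ void example_smem_addressing(float * tile /* pointer into __shared__ memory */) {
        const unsigned int tile_smem = ggml_cuda_cvta_generic_to_shared(tile);
    #pragma unroll
        for (int chunk = 0; chunk < 4; ++chunk) {
            const unsigned int dst = tile_smem + chunk*16; // byte offset of each 16-byte chunk
            // dst is the kind of destination address the cp.async copy helpers
            // defined further down in this header expect.
            GGML_UNUSED(dst);
        }
    }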

+ 13 - 13
ggml/src/ggml-cuda/fattn-common.cuh

@@ -516,7 +516,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
+template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
@@ -665,13 +665,13 @@ static void on_no_fattn_vec_case(const int D) {
         fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
         GGML_ABORT("fatal error");
     } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size 256.\n");
+        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
         fprintf(stderr, "Only f16 is supported.\n");
         GGML_ABORT("fatal error");
     }
 }
 
-template <int D, int ncols1, int ncols2, int KQ_stride>
+template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
     const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE
@@ -691,7 +691,7 @@ void launch_fattn(
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+        "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
@@ -754,10 +754,13 @@ void launch_fattn(
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(warp_size, nwarps, 1);
+    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
+    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+
     dim3 blocks_num;
     if (stream_k) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int max_blocks = 2*nsm;
+        const int max_blocks = max_blocks_per_sm*nsm;
         const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
         const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
@@ -769,14 +772,11 @@ void launch_fattn(
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
     } else {
         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
 
-        int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
-        CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
-
         // parallel_blocks should be at least large enough to achieve max. occupancy for a single wave:
         parallel_blocks = std::max((nsm * max_blocks_per_sm) / ntiles_total, 1);
 
@@ -853,19 +853,19 @@ void launch_fattn(
 
     if (stream_k) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
-            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 block_dim_combine(DV, 1, 1);
             const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
+            flash_attn_stream_k_fixup<DV, ncols1, ncols2>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
     } else if (parallel_blocks > 1) {
-        const dim3 block_dim_combine(D, 1, 1);
+        const dim3 block_dim_combine(DV, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
         const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
 
-        flash_attn_combine_results<D>
+        flash_attn_combine_results<DV>
             <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>>
             (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks);
     }

File diff suppressed because it is too large
+ 457 - 268
ggml/src/ggml-cuda/fattn-mma-f16.cuh


+ 2 - 2
ggml/src/ggml-cuda/fattn-tile-f16.cu

@@ -307,7 +307,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
@@ -315,7 +315,7 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {

+ 2 - 2
ggml/src/ggml-cuda/fattn-tile-f32.cu

@@ -318,7 +318,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
@@ -326,7 +326,7 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, -1>
+            launch_fattn<D, cols_per_block, 1>
                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {

+ 1 - 1
ggml/src/ggml-cuda/fattn-vec-f16.cuh

@@ -315,7 +315,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>

+ 1 - 1
ggml/src/ggml-cuda/fattn-vec-f32.cuh

@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>

+ 1 - 1
ggml/src/ggml-cuda/fattn-wmma-f16.cu

@@ -490,7 +490,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
+    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

+ 71 - 45
ggml/src/ggml-cuda/fattn.cu

@@ -8,58 +8,32 @@
 #include "fattn-wmma-f16.cuh"
 #include "fattn.cuh"
 
-template <int D, int ncols2>
+template <int DKQ, int DV, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
 
-    if (Q->ne[1] <= 8/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
-        return;
+    if constexpr (ncols2 <= 8) {
+        if (Q->ne[1] <= 8/ncols2) {
+            ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 8/ncols2, ncols2>(ctx, dst);
+            return;
+        }
     }
 
     if (Q->ne[1] <= 16/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32/ncols2) {
-        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
-}
-
-template <int ncols2>
-static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-
-    switch (Q->ne[0]) {
-        case 64:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
-            break;
-        case 80:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
-            break;
-        case 96:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
-            break;
-        case 112:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
-            break;
-        case 128:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
-            break;
-        case 256:
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
-            break;
-        default:
-            GGML_ABORT("fatal error");
-            break;
-    }
+    ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 64/ncols2, ncols2>(ctx, dst);
 }
 
-static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <int DKQ, int DV>
+static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
@@ -68,27 +42,79 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const float use_gqa_opt = mask && max_bias == 0.0f;
+    const bool use_gqa_opt = mask && max_bias == 0.0f;
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
 
     if (use_gqa_opt && gqa_ratio % 8 == 0) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 4) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio % 4 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst);
         return;
     }
 
-    if (use_gqa_opt && gqa_ratio == 2) {
-        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
+    if (use_gqa_opt && gqa_ratio % 2 == 0) {
+        ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst);
         return;
     }
 
-    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
+    ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst);
+}
+
+static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
+    const ggml_tensor * mask = dst->src[3];
+
+    switch (Q->ne[0]) {
+        case 64:
+            GGML_ASSERT(V->ne[0] == 64);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 64,  64>(ctx, dst);
+            break;
+        case 80:
+            GGML_ASSERT(V->ne[0] == 80);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 80,  80>(ctx, dst);
+            break;
+        case 96:
+            GGML_ASSERT(V->ne[0] == 96);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2< 96,  96>(ctx, dst);
+            break;
+        case 112:
+            GGML_ASSERT(V->ne[0] == 112);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<112, 112>(ctx, dst);
+            break;
+        case 128:
+            GGML_ASSERT(V->ne[0] == 128);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst);
+            break;
+        case 256:
+            GGML_ASSERT(V->ne[0] == 256);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+            break;
+        case 576: {
+            // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+            GGML_ASSERT(V->ne[0] == 512);
+            float max_bias = 0.0f;
+            memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+            const bool use_gqa_opt = mask && max_bias == 0.0f;
+            GGML_ASSERT(use_gqa_opt);
+
+            GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+            const int gqa_ratio = Q->ne[2] / K->ne[2];
+            GGML_ASSERT(gqa_ratio % 16 == 0);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+        } break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
 
 #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
@@ -299,7 +325,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
     const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
-    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
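
As a small worked example of the ncols1 selection above (a toy restatement, not the actual dispatch code): with the new ncols2 = 16 path used for Deepseek, the constexpr ncols2 <= 8 branch is skipped and the batch size Q->ne[1] picks one of the three new 576/512 instances.

    // Toy mirror of ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1 (illustrative only):
    static int pick_ncols1(const int n_q /* Q->ne[1] */, const int ncols2) {
        if (ncols2 <= 8 && n_q <= 8/ncols2) {
            return 8/ncols2;
        }
        if (n_q <= 16/ncols2) {
            return 16/ncols2;
        }
        if (n_q <= 32/ncols2) {
            return 32/ncols2;
        }
        return 64/ncols2;
    }
    // pick_ncols1(1, 16) == 1 -> fattn-mma-f16-instance-ncols1_1-ncols2_16.cu
    // pick_ncols1(2, 16) == 2 -> fattn-mma-f16-instance-ncols1_2-ncols2_16.cu
    // pick_ncols1(8, 16) == 4 -> fattn-mma-f16-instance-ncols1_4-ncols2_16.cu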

+ 6 - 6
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3215,16 +3215,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return false;
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                // different head sizes of K and V are not supported yet
-                return false;
+                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
+                if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
+                    return false;
+                }
+                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
+                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
-            if (op->src[0]->ne[0] == 576) {
-                // DeepSeek MLA
-                return false;
-            }
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
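
Restated as a standalone predicate (a sketch of the commit's stated conditions; the function name and flattened parameters are illustrative), the new rule for K and V having different head sizes is: only the DeepSeek MLA shape DKQ = 576 / DV = 512 is accepted, and only with a mask, a GQA ratio divisible by 16, and an Ampere-or-newer GPU with the new MMA kernels available.

    // Illustrative sketch, not the actual supports_op code:
    static bool fattn_supports_unequal_head_sizes(const int dkq, const int dv, const bool has_mask,
                                                  const int gqa_ratio, const bool new_mma, const bool ampere_or_newer) {
        if (!new_mma || !ampere_or_newer) {
            return false; // pre-Ampere and non-MMA devices keep the old behavior: unsupported
        }
        return dkq == 576 && dv == 512 && has_mask && gqa_ratio % 16 == 0;
    }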

+ 5 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 1, 8);
-DECL_FATTN_MMA_F16_CASE(80, 1, 8);
-DECL_FATTN_MMA_F16_CASE(96, 1, 8);
-DECL_FATTN_MMA_F16_CASE(112, 1, 8);
-DECL_FATTN_MMA_F16_CASE(128, 1, 8);
-DECL_FATTN_MMA_F16_CASE(256, 1, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 1, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 1);
-DECL_FATTN_MMA_F16_CASE(80, 16, 1);
-DECL_FATTN_MMA_F16_CASE(96, 16, 1);
-DECL_FATTN_MMA_F16_CASE(112, 16, 1);
-DECL_FATTN_MMA_F16_CASE(128, 16, 1);
-DECL_FATTN_MMA_F16_CASE(256, 16, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 1);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 2);
-DECL_FATTN_MMA_F16_CASE(80, 16, 2);
-DECL_FATTN_MMA_F16_CASE(96, 16, 2);
-DECL_FATTN_MMA_F16_CASE(112, 16, 2);
-DECL_FATTN_MMA_F16_CASE(128, 16, 2);
-DECL_FATTN_MMA_F16_CASE(256, 16, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 2);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 16, 4);
-DECL_FATTN_MMA_F16_CASE(80, 16, 4);
-DECL_FATTN_MMA_F16_CASE(96, 16, 4);
-DECL_FATTN_MMA_F16_CASE(112, 16, 4);
-DECL_FATTN_MMA_F16_CASE(128, 16, 4);
-DECL_FATTN_MMA_F16_CASE(256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 16, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 16, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);

+ 5 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 2, 4);
-DECL_FATTN_MMA_F16_CASE(80, 2, 4);
-DECL_FATTN_MMA_F16_CASE(96, 2, 4);
-DECL_FATTN_MMA_F16_CASE(112, 2, 4);
-DECL_FATTN_MMA_F16_CASE(128, 2, 4);
-DECL_FATTN_MMA_F16_CASE(256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 2, 8);
-DECL_FATTN_MMA_F16_CASE(80, 2, 8);
-DECL_FATTN_MMA_F16_CASE(96, 2, 8);
-DECL_FATTN_MMA_F16_CASE(112, 2, 8);
-DECL_FATTN_MMA_F16_CASE(128, 2, 8);
-DECL_FATTN_MMA_F16_CASE(256, 2, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 2, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 32, 1);
-DECL_FATTN_MMA_F16_CASE(80, 32, 1);
-DECL_FATTN_MMA_F16_CASE(96, 32, 1);
-DECL_FATTN_MMA_F16_CASE(112, 32, 1);
-DECL_FATTN_MMA_F16_CASE(128, 32, 1);
-DECL_FATTN_MMA_F16_CASE(256, 32, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 1);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 32, 2);
-DECL_FATTN_MMA_F16_CASE(80, 32, 2);
-DECL_FATTN_MMA_F16_CASE(96, 32, 2);
-DECL_FATTN_MMA_F16_CASE(112, 32, 2);
-DECL_FATTN_MMA_F16_CASE(128, 32, 2);
-DECL_FATTN_MMA_F16_CASE(256, 32, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 32, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 32, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 32, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 32, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 32, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 32, 2);

+ 5 - 0
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu

@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-mma-f16.cuh"
+
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_2.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 2);
-DECL_FATTN_MMA_F16_CASE(80, 4, 2);
-DECL_FATTN_MMA_F16_CASE(96, 4, 2);
-DECL_FATTN_MMA_F16_CASE(112, 4, 2);
-DECL_FATTN_MMA_F16_CASE(128, 4, 2);
-DECL_FATTN_MMA_F16_CASE(256, 4, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 2);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 4);
-DECL_FATTN_MMA_F16_CASE(80, 4, 4);
-DECL_FATTN_MMA_F16_CASE(96, 4, 4);
-DECL_FATTN_MMA_F16_CASE(112, 4, 4);
-DECL_FATTN_MMA_F16_CASE(128, 4, 4);
-DECL_FATTN_MMA_F16_CASE(256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 4, 8);
-DECL_FATTN_MMA_F16_CASE(80, 4, 8);
-DECL_FATTN_MMA_F16_CASE(96, 4, 8);
-DECL_FATTN_MMA_F16_CASE(112, 4, 8);
-DECL_FATTN_MMA_F16_CASE(128, 4, 8);
-DECL_FATTN_MMA_F16_CASE(256, 4, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 4, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 64, 1);
-DECL_FATTN_MMA_F16_CASE(80, 64, 1);
-DECL_FATTN_MMA_F16_CASE(96, 64, 1);
-DECL_FATTN_MMA_F16_CASE(112, 64, 1);
-DECL_FATTN_MMA_F16_CASE(128, 64, 1);
-DECL_FATTN_MMA_F16_CASE(256, 64, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 64, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 64, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 1);
-DECL_FATTN_MMA_F16_CASE(80, 8, 1);
-DECL_FATTN_MMA_F16_CASE(96, 8, 1);
-DECL_FATTN_MMA_F16_CASE(112, 8, 1);
-DECL_FATTN_MMA_F16_CASE(128, 8, 1);
-DECL_FATTN_MMA_F16_CASE(256, 8, 1);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 1);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 1);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 2);
-DECL_FATTN_MMA_F16_CASE(80, 8, 2);
-DECL_FATTN_MMA_F16_CASE(96, 8, 2);
-DECL_FATTN_MMA_F16_CASE(112, 8, 2);
-DECL_FATTN_MMA_F16_CASE(128, 8, 2);
-DECL_FATTN_MMA_F16_CASE(256, 8, 2);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 2);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 2);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 4);
-DECL_FATTN_MMA_F16_CASE(80, 8, 4);
-DECL_FATTN_MMA_F16_CASE(96, 8, 4);
-DECL_FATTN_MMA_F16_CASE(112, 8, 4);
-DECL_FATTN_MMA_F16_CASE(128, 8, 4);
-DECL_FATTN_MMA_F16_CASE(256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 4);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 4);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);

+ 6 - 6
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu

@@ -2,9 +2,9 @@
 
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64, 8, 8);
-DECL_FATTN_MMA_F16_CASE(80, 8, 8);
-DECL_FATTN_MMA_F16_CASE(96, 8, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8, 8);
+DECL_FATTN_MMA_F16_CASE(64, 64, 8, 8);
+DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8);
+DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);

+ 12 - 9
ggml/src/ggml-cuda/template-instances/generate_cu_files.py

@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
 
 """
 
-SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
+SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n"
 
 TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,18 +57,21 @@ for vkq_size in [16, 32]:
                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
 
-for ncols in [8, 16, 32, 64, 128]:
-    for ncols2 in [1, 2, 4, 8]:
+for ncols in [8, 16, 32, 64]:
+    for ncols2 in [1, 2, 4, 8, 16]:
+        if ncols2 > ncols:
+            continue
         ncols1 = ncols // ncols2
-        if ncols == 128:
-            continue  # Too much register pressure.
         with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
             f.write(SOURCE_FATTN_MMA_START)
 
-            for head_size in [64, 80, 96, 112, 128, 256]:
-                if ncols == 128 and head_size == 256:
-                    continue  # Needs too much shared memory.
-                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
+            for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
+                if head_size_kq != 576 and ncols2 == 16:
+                    continue
+                if head_size_kq == 576 and ncols2 != 16:
+                    continue
+                head_size_v = head_size_kq if head_size_kq != 576 else 512
+                f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))
 
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

+ 11 - 0
src/llama-graph.cpp

@@ -1227,8 +1227,19 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient.
             cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
             cur = ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+#endif
         }
 
         cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
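
For readers unfamiliar with the MLA decompression step, here is a shape walkthrough of the permuted path above, assuming DeepSeek-style dimensions (kv_lora_rank = 512, n_embd_head_v = 128, per-head v_mla with ne = [512, 128, n_head, 1]); the concrete numbers are assumptions for illustration, the calls themselves are the ones from the diff:

    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // [512, n_head, n_tokens, 1] -> [512, n_tokens, n_head, 1] (view only)
    cur = ggml_mul_mat(ctx0, v_mla, cur);      // -> [128, n_tokens, n_head, 1]; dims 0 and 1 are large, so the matmul takes the fast path
    cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // -> [128, n_head, n_tokens, 1]
    cur = ggml_cont(ctx0, cur);                // materialize the permuted view for the ggml_reshape_2d below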

Some files were not shown because too many files changed in this diff