
CUDA: re-use MLA K data for V in MMA FA (#19057)

Johannes Gäßler, 6 days ago
commit 8f91ca54ec
3 changed files with 73 additions and 71 deletions
  1. ggml/src/ggml-cuda/fattn-common.cuh (+39 -37)
  2. ggml/src/ggml-cuda/fattn-mma-f16.cuh (+29 -34)
  3. ggml/src/ggml-cuda/fattn.cu (+5 -0)

ggml/src/ggml-cuda/fattn-common.cuh (+39 -37)

@@ -782,12 +782,7 @@ void launch_fattn(
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    // TODO: make this more generic by removing the notion of "MLA".
-    //       for example "is V a view of K?" so we can skip loading it.
-    //       V strides should be driven by V itself and avoid assumption of the data layout
-    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
-
-    GGML_ASSERT(V || is_mla);
+    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
 
     const ggml_tensor * mask  = dst->src[3];
     const ggml_tensor * sinks = dst->src[4];
@@ -797,9 +792,9 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
-    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
-    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
-    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
 
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
 
@@ -820,10 +815,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = V ? (const char *) V->data : nullptr;
-    size_t nb21 = V ? V->nb[1] : nb11;
-    size_t nb22 = V ? V->nb[2] : nb12;
-    size_t nb23 = V ? V->nb[3] : nb13;
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         const size_t bs = ggml_blck_size(K->type);
@@ -852,32 +847,39 @@ void launch_fattn(
         K_data = (char *) K_f16.ptr;
     }
 
-    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        const size_t bs = ggml_blck_size(V->type);
-        const size_t ts = ggml_type_size(V->type);
-
-        V_f16.alloc(ggml_nelements(V));
-        if (ggml_is_contiguously_allocated(V)) {
-            to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-            to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
-            V_data = (char *) V_f16.ptr;
-
-            nb21 = nb21*bs*sizeof(half)/ts;
-            nb22 = nb22*bs*sizeof(half)/ts;
-            nb23 = nb23*bs*sizeof(half)/ts;
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        if (V_is_K_view) {
+            V_data = K_data;
+            nb21   = nb11;
+            nb22   = nb12;
+            nb23   = nb13;
         } else {
-            GGML_ASSERT(V->nb[0] == ts);
-            to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
-            const int64_t s01 = nb21 / ts;
-            const int64_t s02 = nb22 / ts;
-            const int64_t s03 = nb23 / ts;
-            to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
-
-            nb21 = V->ne[0] * sizeof(half);
-            nb22 = V->ne[1] * nb21;
-            nb23 = V->ne[2] * nb22;
+            const size_t bs = ggml_blck_size(V->type);
+            const size_t ts = ggml_type_size(V->type);
+
+            V_f16.alloc(ggml_nelements(V));
+            if (ggml_is_contiguously_allocated(V)) {
+                to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+                V_data = (char *) V_f16.ptr;
+
+                nb21 = nb21*bs*sizeof(half)/ts;
+                nb22 = nb22*bs*sizeof(half)/ts;
+                nb23 = nb23*bs*sizeof(half)/ts;
+            } else {
+                GGML_ASSERT(V->nb[0] == ts);
+                to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type);
+                const int64_t s01 = nb21 / ts;
+                const int64_t s02 = nb22 / ts;
+                const int64_t s03 = nb23 / ts;
+                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+                nb21 = V->ne[0] * sizeof(half);
+                nb22 = V->ne[1] * nb21;
+                nb23 = V->ne[2] * nb22;
+            }
+            V_data = (char *) V_f16.ptr;
         }
-        V_data = (char *) V_f16.ptr;
     }
 
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
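
The host-side change above amounts to: when V is just a view of K's buffer, skip the separate V→F16 conversion and let V ride on whatever K_data already points to (including a freshly converted F16 copy of K). The sketch below mirrors that control flow with mocked-up types; `TensorDesc` and `resolve_V` are illustrative names, not part of ggml.

```cpp
// Standalone sketch (assumed names, not the ggml API) of the launch_fattn() change:
// if V aliases K's data, re-use K's (possibly converted) buffer and byte strides
// for V instead of running a second F16 conversion.
#include <cstddef>

struct TensorDesc {
    const char * data = nullptr;
    size_t       nb1 = 0, nb2 = 0, nb3 = 0; // byte strides, as in ggml_tensor::nb[1..3]
    bool         is_view_of_K = false;      // stands in for: op == VIEW && src[0] == K && data == K->data
};

// Pick the pointer/strides the kernel should use for V, given the K buffer that
// the launcher has already prepared (K_data may be a converted F16 copy of K).
static void resolve_V(const char * K_data, const TensorDesc & K, TensorDesc & V,
                      bool need_f16_V, bool V_already_f16) {
    if (!need_f16_V || V_already_f16) {
        return;                    // V can be consumed as-is
    }
    if (V.is_view_of_K) {
        V.data = K_data;           // re-use K's buffer, no second conversion
        V.nb1  = K.nb1;
        V.nb2  = K.nb2;
        V.nb3  = K.nb3;
    } else {
        // otherwise V gets its own F16 conversion, exactly as in the real code
    }
}
```

The point is that any quantized→F16 conversion of the shared buffer happens once, on the K path, and V simply inherits K's pointer and strides.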

ggml/src/ggml-cuda/fattn-mma-f16.cuh (+29 -34)

@@ -400,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
 }
 
 template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
-    bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
+    bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
     typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
@@ -442,8 +442,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
 
     const int k_VKQ_0 = kb0 * nbatch_fa;
 #if defined(TURING_MMA_AVAILABLE)
@@ -456,7 +455,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
     if constexpr (nstages > 1) {
         static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
-        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -471,8 +470,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         }
     }
 
+    // For MLA, K and V have the same data.
+    // Therefore, iterate over K in reverse and re-use the data for V later where possible.
 #pragma unroll
-    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+    for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
         const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
@@ -776,6 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
     if constexpr (nstages > 1) {
+        static_assert(!V_is_K_view, "K data reuse not implemented for multi-stage loading");
         // Preload K tile for next iteration:
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
@@ -791,11 +793,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 
-    // For MLA K and V have the same data.
-    // Therefore, iterate over V in reverse and re-use the data if possible.
-    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
-    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
@@ -803,12 +800,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 
     // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
 #pragma unroll
-    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
-        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
-        const int i0_diff  = i0_stop - i0_start;
+    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
+        static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
+        const int i0_stop = i0_start + 2*nbatch_V2;
+        const int i0_diff = i0_stop - i0_start;
 
         if constexpr (nstages <= 1) {
-            if (i0_start < reusable_cutoff) {
+            if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
                 constexpr bool use_cp_async = nstages == 1;
                 flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
                     (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -818,7 +816,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 __syncthreads();
             }
         }
-        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
+        const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
 
 #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
@@ -921,7 +919,7 @@ template<int ncols> struct mma_tile_sizes {
 };
 #endif // defined(TURING_MMA_AVAILABLE)
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -975,8 +973,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     constexpr int stride_tile_Q = DKQ/2     + 4;
     constexpr int stride_tile_K = nbatch_K2 + 4;
 
-    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
-    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
+    constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
     extern __shared__ half2 tile_Q[];
@@ -1080,7 +1077,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
                  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1089,7 +1086,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         const     int  k_VKQ_sup = ne11 - kb0*nbatch_fa;
         flash_attn_ext_f16_iter
-            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
               T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1100,7 +1097,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             constexpr bool last_iter = false;
             constexpr int  k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
-                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+                <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
                  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1109,7 +1106,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         constexpr int  k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
-            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
              T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
             (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1457,7 +1454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
+template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
 __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -1509,8 +1506,6 @@ static __global__ void flash_attn_ext_f16(
     }
 #endif // defined(AMD_WMMA_AVAILABLE)
 
-    static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
-
     constexpr int ncols     = ncols1 * ncols2;
     constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int nthreads  = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
@@ -1523,7 +1518,7 @@ static __global__ void flash_attn_ext_f16(
     const int stride_K    = nb11 / sizeof(half2);
     const int stride_mask = nb31 / sizeof(half);
 
-    const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+    const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
 
     const int iter_k = (ne11   + (nbatch_fa - 1)) / nbatch_fa;
     const int iter_j = (ne01.z + (ncols1    - 1)) / ncols1;
@@ -1553,7 +1548,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1564,12 +1559,12 @@ static __global__ void flash_attn_ext_f16(
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
         }
@@ -1597,7 +1592,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
 
-    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
 
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1608,7 +1603,7 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
+    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
 #else
@@ -1644,7 +1639,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
     const int nwarps        = nthreads / WARP_SIZE;
 
-    constexpr bool mla = DKQ == 576;
+    constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
 
     const size_t nbytes_shared_KV_1stage = nbatch_fa            * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
     const size_t nbytes_shared_KV_2stage = nbatch_fa            *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
@@ -1669,7 +1664,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
 
 #if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1680,7 +1675,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
 #endif // !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
 
 #if !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
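
The kernel-side counterpart: the K loop now runs in reverse so that the last K sub-tile left in shared memory is the one starting at offset 0, and the (now forward) V loop can consume any V sub-tile that lies entirely inside those first 2*nbatch_K2 logical elements without reloading it from global memory. The toy program below only prints which V sub-tiles would need a fresh load; DV, nbatch_K2, and nbatch_V2 are assumed values for illustration, not the kernel's actual configuration.

```cpp
// Illustrative re-use decision, evaluated on the host for readability.
// With V aliasing K, the shared-memory tile still holds K elements [0, 2*nbatch_K2)
// after the reversed K loop; a V sub-tile inside that range is consumed in place.
#include <cstdio>

int main() {
    const int DV        = 512;  // logical V head size (MLA)
    const int nbatch_K2 = 128;  // half2 elements of K resident per batch (assumed)
    const int nbatch_V2 = 128;  // half2 elements of V processed per VKQ step (assumed)

    for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
        const int  i0_stop    = i0_start + 2*nbatch_V2;
        const bool needs_load = i0_stop > 2*nbatch_K2; // kernel condition with V_is_K_view == true
        printf("V[%3d..%3d): %s\n", i0_start, i0_stop,
               needs_load ? "load from global memory" : "re-use resident K tile");
    }
    return 0;
}
```

With these made-up numbers, half of V is served straight from the resident K tile; the real savings depend on how the kernel picks nbatch_K2/nbatch_V2 for the 576/512 MLA head sizes.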

ggml/src/ggml-cuda/fattn.cu (+5 -0)

@@ -247,6 +247,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         }
     }
 
+    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
+
     const int cc = ggml_cuda_info().devices[device].cc;
 
     switch (K->ne[0]) {
@@ -269,6 +271,9 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
                 return BEST_FATTN_KERNEL_NONE;
             }
+            if (!V_is_K_view) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;
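
For context on when the new V_is_K_view gate is satisfied: it holds whenever the graph builds V as a zero-offset view into K, which is how MLA-style attention can share one compressed KV buffer. A hedged example follows; the helper and the shapes are illustrative, not lifted from llama.cpp's graph code.

```cpp
// Hedged sketch: if V is a zero-offset view into K, then V->op == GGML_OP_VIEW,
// V->src[0] == K and V->data == K->data, so both ggml_cuda_get_best_fattn_kernel()
// and launch_fattn() treat V as a pure alias of K.
#include "ggml.h"

static struct ggml_tensor * make_V_as_K_view(struct ggml_context * ctx, struct ggml_tensor * K) {
    // K is assumed to have ne[0] == 576 (DeepSeek-style MLA); V re-uses the
    // first 512 elements of each K row, keeping K's byte strides.
    return ggml_view_3d(ctx, K,
            512, K->ne[1], K->ne[2],   // ne0, ne1, ne2
            K->nb[1], K->nb[2],        // same row/plane strides as K
            0);                        // zero byte offset => V->data == K->data
}
```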