
mla : make the V tensor a view of K (#18986)

* mla : pass V as a view of K to the FA op

* cuda : adjust mla logic to new layout

* kv-cache : fix rope shift

* tests : remove comment

* cuda : fix reusable_cutoff

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Georgi Gerganov · 6 days ago
commit a5eaa1d6a3

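A rough standalone sketch of the central idea, before the per-file diffs (this is not code from the commit itself): for MLA the V data is simply the leading part of each cached K row, so V can be handed to the flash-attention op as a ggml view of K, and backends can detect that case generically instead of keying on head sizes. The 576/512 sizes are the DeepSeek-style MLA head sizes exercised in the tests; the 2-D shapes and the tiny context size are illustrative only.

    #include "ggml.h"
    #include <cassert>

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
        ggml_context * ctx = ggml_init(params);

        const int64_t n_kv = 32; // hypothetical number of cached tokens

        // each cached K row stores [compressed/"nope" | rope] = 512 + 64 = 576 elements
        ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 576, n_kv);

        // V is the first 512 elements of every K row: same data, same row stride, offset 0
        ggml_tensor * v = ggml_view_2d(ctx, k, 512, n_kv, k->nb[1], 0);

        // this is, in spirit, the check this commit adds to launch_fattn() below
        const bool is_mla = v->op == GGML_OP_VIEW && v->src[0] == k;
        assert(is_mla);

        ggml_free(ctx);
        return 0;
    }
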
+ 5 - 2
ggml/src/ggml-cuda/fattn-common.cuh

@@ -778,12 +778,15 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;

-    const bool is_mla = DV == 512; // TODO better parameterization
-
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];

+    // TODO: make this more generic by removing the notion of "MLA".
+    //       for example "is V a view of K?" so we can skip loading it.
+    //       V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;
+
     GGML_ASSERT(V || is_mla);

     const ggml_tensor * mask  = dst->src[3];

+ 4 - 3
ggml/src/ggml-cuda/fattn-mma-f16.cuh

@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     // For MLA K and V have the same data.
     // Therefore, iterate over V in reverse and re-use the data if possible.
     static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
 #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
     T_A_VKQ A_identity;
     make_identity_mat(A_identity);
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
             (const half *) (mask + nb33*(sequence % ne33));
         float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+        const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
         const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
         (const half *) (mask + nb33*(sequence % ne33));
     float2       * dstk   = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
     const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;

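Why the `+ (DKQ/2 - DV/2)` offset disappears in the V_h2 lines above: before this change the cached K rows were laid out rope-first, so the V data (the compressed part) started DKQ - DV elements into each row; with the reordered concat in deepseek2.cpp the rows are [nope | rope] and the V data is a prefix of the row, so V_h2 can simply alias K_h2. A small worked example, assuming the DeepSeek-style sizes from the tests (DKQ = 576, DV = 512):

    int main() {
        constexpr int DKQ = 576; // K head size: 512 compressed/"nope" + 64 rope
        constexpr int DV  = 512; // V head size: the compressed part only

        // old per-row layout [rope | nope]: V started (DKQ - DV) halves into the row,
        // i.e. DKQ/2 - DV/2 half2 elements -- the term removed above
        constexpr int v_offset_old_h2 = DKQ/2 - DV/2;
        static_assert(v_offset_old_h2 == 32, "64 half elements == 32 half2 elements");

        // new per-row layout [nope | rope]: V is a prefix of the row, so the offset is 0
        // and V_h2 == K_h2
        return 0;
    }
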
+ 5 - 0
src/llama-graph.cpp

@@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }

+        // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
+        if (v_mla) {
+            v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+        }
+
         // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
         if (k->type == GGML_TYPE_F32) {
             k = ggml_cast(ctx0, k, GGML_TYPE_F16);

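The subtle part of the hunk above is the mix of arguments: the view keeps V's own logical shape (v->ne[...]) but inherits K's strides (k->nb[...]), so each V row is 512 elements of payload inside a 576-element K row and is therefore not contiguous. A sketch of that stride behaviour, with illustrative dimensions (32 cached tokens, 4 KV heads) rather than the real cache shapes:

    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params); // no_alloc: we only care about shapes/strides

        const int64_t kv = 32, nh = 4;
        ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, 576, kv, nh);
        ggml_tensor * v = ggml_view_3d(ctx, k, 512, kv, nh, k->nb[1], k->nb[2], 0);

        // V inherits K's strides: 512 elements of payload per row, but the next row
        // starts a full 576-element K row later -- so V is not contiguous
        GGML_ASSERT(v->nb[1] == 576*ggml_type_size(GGML_TYPE_F16));
        GGML_ASSERT(!ggml_is_contiguous(v));

        ggml_free(ctx);
        return 0;
    }
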
+ 6 - 2
src/llama-kv-cache.cpp

@@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
     const auto & n_embd_head_k = hparams.n_embd_head_k;
   //const auto & n_embd_head_v = hparams.n_embd_head_v;

+    const auto & n_rot = hparams.n_rot;
+
+    const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;
+
     auto inp = std::make_unique<llm_graph_input_k_shift>(this);

     inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co

         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                n_rot, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
+                ggml_row_size(layer.k->type, n_embd_nope));

         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);


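With the [nope | rope] row order, only the trailing n_rot columns of each cached K row carry RoPE, which is why the shift graph above now views just that slice, offset by the "nope" width. The byte arithmetic, assuming a DeepSeek-style MLA layer (kv_lora_rank = 512, n_rot = 64, f16 cache) purely for illustration:

    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t elt           = 2;                     // sizeof(ggml_fp16_t)
        const size_t n_rot         = 64;                    // rope columns per row
        const size_t n_embd_head_k = 512 + n_rot;           // [nope | rope] = 576
        const size_t n_embd_nope   = n_embd_head_k - n_rot; // 512 (only when n_lora_kv > 0)

        // the new ggml_view_3d above selects only the rope columns of every row:
        const size_t ne0    = n_rot;               // 64 elements wide
        const size_t offset = elt*n_embd_nope;     // ggml_row_size(f16, 512) = 1024 bytes in
        const size_t nb1    = elt*n_embd_head_k;   // row stride is still the full K row

        printf("rope-shift view: ne0=%zu, offset=%zu B, row stride=%zu B\n", ne0, offset, nb1);
        return 0;
    }
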
+ 4 - 5
src/models/deepseek2.cpp

@@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr

                 // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                 // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                 cb(Qcur, "Qcur", il);

                 kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                 cb(kv_cmpr, "kv_cmpr_reshape", il);

                 // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-                ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                 cb(Kcur, "Kcur", il);

                 // {kv_lora_rank, 1, n_tokens}
@@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 Vcur = ggml_cont(ctx0, Vcur);
                 cb(Vcur, "Vcur_cont", il);

-                // note: rope must go first for in-place context shifting in build_rope_shift()
-                ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
                 cb(Qcur, "Qcur", il);

-                ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+                ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
                 cb(Kcur, "Kcur", il);

                 if (inp_attn_scale) {

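The concat reordering here is what makes the rest of the commit line up: Q and K rows for DeepSeek2 are now built as [nope | rope] rather than [rope | nope], so the cached K rows have the compressed part first, the rope-shift view in llama-kv-cache.cpp can address the rope tail by offset, and the MLA V becomes a plain prefix view of K. A compact sketch of the new row layout, using illustrative dims (kv_lora_rank = 512, rope width 64, a handful of tokens) rather than the model's real graph:

    #include "ggml.h"

    int main() {
        ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ true };
        ggml_context * ctx = ggml_init(params); // shapes only, no data needed

        const int64_t n_tokens = 4;
        ggml_tensor * kv_cmpr = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 512, 1, n_tokens); // compressed/"nope"
        ggml_tensor * k_pe    = ggml_new_tensor_3d(ctx, GGML_TYPE_F32,  64, 1, n_tokens); // rope part

        // new order: compressed part first, rope last -> each row is [512 nope | 64 rope]
        ggml_tensor * Kcur = ggml_concat(ctx, kv_cmpr, k_pe, 0);
        GGML_ASSERT(Kcur->ne[0] == 512 + 64);

        // the MLA V is exactly the first 512 columns of this layout, hence the prefix view of K
        ggml_free(ctx);
        return 0;
    }
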
+ 1 - 0
src/models/minicpm3.cpp

@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap

     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;

     ggml_tensor * cur;

+ 1 - 0
src/models/plm.cpp

@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &

     const uint32_t n_embd_head_qk_rope = hparams.n_rot;
     const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
+
     const uint32_t kv_lora_rank = hparams.n_lora_kv;

     ggml_tensor * cur;

+ 13 - 1
tests/test-backend-ops.cpp

@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache
         ggml_set_name(k, "k");

-        ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        ggml_tensor * v = nullptr;
+        if (hsk_padded == 576 && hsv_padded == 512) {
+            // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+            // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models
+            // for more info:
+            //   - https://github.com/ggml-org/llama.cpp/pull/13435
+            //   - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+            //   - https://github.com/ggml-org/llama.cpp/pull/18986
+            v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+        } else {
+            v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache
+        }
         ggml_set_name(v, "v");

         ggml_tensor * m = nullptr;