Browse Source

kv-cache : support V-less cache (#19067)

* kv-cache : support V-less cache

* cuda : better check for V_is_K_view

* cuda : improve V_is_K_view check

* graph : add comments

* hparams : refactor
Georgi Gerganov 3 days ago
parent
commit
d9c6ce46f7

+ 1 - 1
ggml/src/ggml-cuda/fattn-common.cuh

@@ -782,7 +782,7 @@ void launch_fattn(
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
-    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
+    const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
 
     const ggml_tensor * mask  = dst->src[3];
     const ggml_tensor * sinks = dst->src[4];

+ 1 - 1
ggml/src/ggml-cuda/fattn.cu

@@ -247,7 +247,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         }
     }
 
-    const bool V_is_K_view = V->op == GGML_OP_VIEW && V->src[0] == K && V->data == K->data;
+    const bool V_is_K_view = V->view_src && V->view_offs == 0 && (V->view_src == K || V->view_src == K->view_src);
 
     const int cc = ggml_cuda_info().devices[device].cc;
 

+ 4 - 4
src/llama-context.cpp

@@ -793,7 +793,7 @@ float * llama_context::get_embeddings_ith(int32_t i) {
             throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs));
         }
 
-        const uint32_t n_embd_out = model.hparams.get_n_embd_out();
+        const uint32_t n_embd_out = model.hparams.n_embd_out();
         return embd + j*n_embd_out;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
@@ -1279,7 +1279,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
                 {
                     // extract token embeddings
                     GGML_ASSERT(embd != nullptr);
-                    const uint32_t n_embd_out = hparams.get_n_embd_out();
+                    const uint32_t n_embd_out = hparams.n_embd_out();
 
                     GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size);
                     ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float));
@@ -1688,7 +1688,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     {
                         // extract token embeddings
                         GGML_ASSERT(embd != nullptr);
-                        const uint32_t n_embd_out = hparams.get_n_embd_out();
+                        const uint32_t n_embd_out = hparams.n_embd_out();
                         float * embd_out = embd + n_outputs_prev*n_embd_out;
 
                         if (n_outputs) {
@@ -1821,7 +1821,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
 
     const auto n_batch    = cparams.n_batch;
     const auto n_vocab    = vocab.n_tokens();
-    const auto n_embd_out = hparams.get_n_embd_out();
+    const auto n_embd_out = hparams.n_embd_out();
 
     bool has_logits = true;
     bool has_embd   = cparams.embeddings;

+ 111 - 6
src/llama-graph.cpp

@@ -407,6 +407,27 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) {
+    mctx->set_input_k_idxs(self_k_idxs, ubatch);
+
+    mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+}
+
+bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
+
+    res &= self_kq_mask->ne[0] == mctx->get_n_kv();
+    res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
+
+    return res;
+}
+
 void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
@@ -1596,11 +1617,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }
 
-        // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
-        if (v_mla) {
-            v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
-        }
-
         // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
         if (k->type == GGML_TYPE_F32) {
             k = ggml_cast(ctx0, k, GGML_TYPE_F16);
@@ -1823,9 +1839,11 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
         ggml_tensor * sinks,
-        ggml_tensor * v_mla,
+        ggml_tensor * v_mla, // TODO: remove
             float     kq_scale,
             int       il) const {
+    GGML_ASSERT(v_mla == nullptr);
+
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     // expand k later to enable rope fusion which directly writes into k-v cache
@@ -1868,6 +1886,93 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
+static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
+           ggml_context * ctx0,
+     const llama_ubatch & ubatch,
+    const llama_hparams & hparams,
+    const llama_cparams & cparams,
+    const llama_kv_cache_context * mctx_cur) {
+
+    auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur);
+
+    {
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
+
+        const auto n_kv     = mctx_cur->get_n_kv();
+        const auto n_tokens = ubatch.n_tokens;
+        const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+        inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
+
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
+
+    return inp;
+}
+
+llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
+
+    auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+
+    return (llm_graph_input_attn_k *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_k * inp,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * sinks,
+        ggml_tensor * v_mla,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);
+
+    const auto * mctx_cur = inp->mctx;
+
+    // store to KV cache
+    {
+        const auto & k_idxs = inp->get_k_idxs();
+
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
+    }
+
+    const auto & kq_mask = inp->get_kq_mask();
+
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+
+    ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il);
+    cb(cur, "kqv_out", il);
+
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+        if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) {
+            // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
+    }
+
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
+
+    return cur;
+}
+
 ggml_tensor * llm_graph_context::build_attn(
         llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,

+ 48 - 0
src/llama-graph.h

@@ -317,6 +317,39 @@ public:
     const llama_kv_cache_context * mctx;
 };
 
+// V-less input for the KV cache
+// ref: https://github.com/ggml-org/llama.cpp/pull/19067
+class llm_graph_input_attn_k : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_k(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_context * mctx) :
+        hparams(hparams),
+        cparams(cparams),
+        mctx(mctx) {
+    }
+    ~llm_graph_input_attn_k() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    bool can_reuse(const llm_graph_params & params) override;
+
+    ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch/n_stream, 1, n_stream]
+
+    const llama_hparams hparams;
+    const llama_cparams cparams;
+
+    const llama_kv_cache_context * mctx;
+};
+
 class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
     llm_graph_input_attn_kv_iswa(
@@ -833,6 +866,21 @@ struct llm_graph_context {
             ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
             ggml_tensor * kq_b,
             ggml_tensor * sinks, // [n_head_q]
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove
+                  float   kq_scale,
+                    int   il) const;
+
+    llm_graph_input_attn_k  * build_attn_inp_k() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_k * inp,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * sinks, // [n_head_q]
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                   float   kq_scale,
                     int   il) const;

+ 17 - 2
src/llama-hparams.cpp

@@ -72,8 +72,8 @@ uint32_t llama_hparams::n_embd_inp() const {
     return n_embd_inp;
 }
 
-uint32_t llama_hparams::get_n_embd_out() const {
-    return n_embd_out > 0 ? n_embd_out : n_embd;
+uint32_t llama_hparams::n_embd_out() const {
+    return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd;
 }
 
 uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const {
@@ -175,6 +175,21 @@ bool llama_hparams::is_swa(uint32_t il) const {
     GGML_ABORT("fatal error");
 }
 
+bool llama_hparams::is_mla() const {
+    assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) ||
+           (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0));
+
+    return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0;
+}
+
+uint32_t llama_hparams::n_embd_head_k_mla() const {
+    return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k;
+}
+
+uint32_t llama_hparams::n_embd_head_v_mla() const {
+    return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v;
+}
+
 bool llama_hparams::has_kv(uint32_t il) const {
     if (n_layer_kv_from_start >= 0) {
         if (il < (uint32_t) n_layer_kv_from_start) {

+ 10 - 4
src/llama-hparams.h

@@ -53,8 +53,8 @@ struct llama_hparams {
     uint32_t n_rel_attn_bkts = 0;
 
     // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA
-    uint32_t n_embd_head_k_mla = 0;
-    uint32_t n_embd_head_v_mla = 0;
+    uint32_t n_embd_head_k_mla_impl = 0;
+    uint32_t n_embd_head_v_mla_impl = 0;
 
     // for WavTokenizer
     struct llama_hparams_posnet   posnet;
@@ -164,7 +164,7 @@ struct llama_hparams {
     uint32_t n_cls_out = 1;
 
     // output embedding dimension (0 = use n_embd)
-    uint32_t n_embd_out = 0;
+    uint32_t n_embd_out_impl = 0;
 
     // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
@@ -239,7 +239,7 @@ struct llama_hparams {
     uint32_t n_embd_inp() const;
 
     // dimension of output embeddings
-    uint32_t get_n_embd_out() const;
+    uint32_t n_embd_out() const;
 
     // dimension of key embeddings across all k-v heads
     uint32_t n_embd_k_gqa(uint32_t il = 0) const;
@@ -269,6 +269,12 @@ struct llama_hparams {
 
     bool is_swa(uint32_t il) const;
 
+    // note: currently only support if either all or none of the layers are MLA
+    bool is_mla() const;
+
+    uint32_t n_embd_head_k_mla() const;
+    uint32_t n_embd_head_v_mla() const;
+
     bool has_kv(uint32_t il) const;
 
     // number of layers for which has_kv() returns true

+ 28 - 8
src/llama-kv-cache.cpp

@@ -97,6 +97,8 @@ llama_kv_cache::llama_kv_cache(
                 __func__, hparams.n_embd_v_gqa_max());
     }
 
+    const bool is_mla = hparams.is_mla();
+
     for (uint32_t il = 0; il < hparams.n_layer; il++) {
         if (!hparams.has_kv(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il);
@@ -130,18 +132,21 @@ llama_kv_cache::llama_kv_cache(
             throw std::runtime_error("failed to create ggml context for kv cache");
         }
 
-        ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
-        ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
+        const bool has_k = true;
+        const bool has_v = !is_mla;
+
+        ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr;
+        ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr;
 
-        ggml_format_name(k, "cache_k_l%d", il);
-        ggml_format_name(v, "cache_v_l%d", il);
+        has_k && ggml_format_name(k, "cache_k_l%d", il);
+        has_v && ggml_format_name(v, "cache_v_l%d", il);
 
         std::vector<ggml_tensor *> k_stream;
         std::vector<ggml_tensor *> v_stream;
 
         for (uint32_t s = 0; s < n_stream; ++s) {
-            k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]));
-            v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]));
+            k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr);
+            v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr);
         }
 
         map_layer_ids[il] = layers.size();
@@ -647,7 +652,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co
                 const auto & layer = layers[il];
 
                 ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]);
-                ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+
+                if (layer.v_stream[ssrc]) {
+                    ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]);
+                }
             }
         }
     }
@@ -1516,7 +1524,7 @@ size_t llama_kv_cache::size_v_bytes() const {
     size_t size_v_bytes = 0;
 
     for (const auto & layer : layers) {
-        size_v_bytes += ggml_nbytes(layer.v);
+        size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0;
     }
 
     return size_v_bytes;
@@ -1798,6 +1806,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             auto * v = layer.v_stream[cr.strm];
+            if (!v) {
+                continue;
+            }
 
             // Write value type
             const int32_t v_type_i = (int32_t) v->type;
@@ -1824,6 +1835,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             auto * v = layer.v_stream[cr.strm];
+            if (!v) {
+                continue;
+            }
 
             // Write value type
             const int32_t v_type_i = (int32_t) v->type;
@@ -2027,6 +2041,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             auto * v = layer.v_stream[strm];
+            if (!v) {
+                continue;
+            }
 
             // Read type of value
             int32_t v_type_i_ref;
@@ -2068,6 +2085,9 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
             auto * v = layer.v_stream[strm];
+            if (!v) {
+                continue;
+            }
 
             // Read type of value
             int32_t v_type_i_ref;

+ 2 - 2
src/llama-model-saver.cpp

@@ -146,8 +146,8 @@ void llama_model_saver::add_kv_from_model() {
     add_kv(LLM_KV_VOCAB_SIZE,                        vocab.n_tokens());
     add_kv(LLM_KV_CONTEXT_LENGTH,                    hparams.n_ctx_train);
     add_kv(LLM_KV_EMBEDDING_LENGTH,                  hparams.n_embd);
-    if (hparams.n_embd_out > 0) {
-        add_kv(LLM_KV_EMBEDDING_LENGTH_OUT,          hparams.n_embd_out);
+    if (hparams.n_embd_out_impl > 0) {
+        add_kv(LLM_KV_EMBEDDING_LENGTH_OUT,          hparams.n_embd_out_impl);
     }
     add_kv(LLM_KV_BLOCK_COUNT,                       hparams.n_layer);
     add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT,         hparams.n_layer_dense_lead);

+ 14 - 16
src/llama-model.cpp

@@ -512,7 +512,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
 
     ml.get_key(LLM_KV_CONTEXT_LENGTH,          hparams.n_ctx_train);
     ml.get_key(LLM_KV_EMBEDDING_LENGTH,        hparams.n_embd);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT,    hparams.n_embd_out, false);
+    ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT,    hparams.n_embd_out_impl, false);
     ml.get_key(LLM_KV_BLOCK_COUNT,             hparams.n_layer);
     ml.get_key(LLM_KV_EXPERT_COUNT,            hparams.n_expert,        false);
     ml.get_key(LLM_KV_EXPERT_USED_COUNT,       hparams.n_expert_used,   false);
@@ -1697,15 +1697,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
         case LLM_ARCH_DEEPSEEK2:
             {
                 // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-                bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
+                const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
                 if (!is_lite) {
                     ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
                 }
                 ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK,     hparams.n_lora_kv);
-                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA,   hparams.n_embd_head_k_mla, false);
-                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false);
+                ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA,   hparams.n_embd_head_k_mla_impl, false);
+                ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
                 ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,        hparams.n_expert_shared);
                 ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,       hparams.expert_weights_scale, false);
@@ -4909,14 +4910,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_DEEPSEEK2:
                 {
-                    // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-                    const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-
-                    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+                    const bool is_mla = hparams.is_mla();
 
                     // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
-                    const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
-                    const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+                    const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
+                    const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
 
                     const int64_t n_embd_head_qk_rope = hparams.n_rot;
                     const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
@@ -4941,13 +4939,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         auto & layer = layers[i];
 
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
-                        if (!is_lite) {
+                        if (q_lora_rank > 0) {
                             layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
                         }
 
                         layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
 
-                        if (!is_lite) {
+                        if (q_lora_rank > 0) {
                             layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
                             layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0);
                         } else {
@@ -6597,7 +6595,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
 
                     // for LFM2-ColBert-350M
-                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED);
+                    dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED);
                 } break;
             case LLM_ARCH_SMALLTHINKER:
                 {
@@ -7316,8 +7314,8 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead    = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q              = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv             = %d\n",     __func__, hparams.n_lora_kv);
-        LLAMA_LOG_INFO("%s: n_embd_head_k_mla     = %d\n",     __func__, hparams.n_embd_head_k_mla);
-        LLAMA_LOG_INFO("%s: n_embd_head_v_mla     = %d\n",     __func__, hparams.n_embd_head_v_mla);
+        LLAMA_LOG_INFO("%s: n_embd_head_k_mla     = %d\n",     __func__, hparams.n_embd_head_k_mla());
+        LLAMA_LOG_INFO("%s: n_embd_head_v_mla     = %d\n",     __func__, hparams.n_embd_head_v_mla());
         LLAMA_LOG_INFO("%s: n_ff_exp              = %d\n",     __func__, hparams.n_ff_exp);
         LLAMA_LOG_INFO("%s: n_expert_shared       = %d\n",     __func__, hparams.n_expert_shared);
         LLAMA_LOG_INFO("%s: expert_weights_scale  = %.1f\n",   __func__, hparams.expert_weights_scale);
@@ -8162,7 +8160,7 @@ int32_t llama_model_n_embd_inp(const llama_model * model) {
 }
 
 int32_t llama_model_n_embd_out(const llama_model * model) {
-    return model->hparams.get_n_embd_out();
+    return model->hparams.n_embd_out();
 }
 
 int32_t llama_model_n_layer(const llama_model * model) {

+ 10 - 9
src/models/deepseek2.cpp

@@ -2,14 +2,11 @@
 
 llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) :
     llm_graph_context(params) {
-    // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
-    bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
-
-    const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0);
+    const bool is_mla = hparams.is_mla();
 
     // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
-    const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k;
-    const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v;
+    const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
+    const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();
 
     const int64_t n_embd_head_qk_rope = hparams.n_rot;
     const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;
@@ -43,7 +40,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
     // inp_pos - contains the positions
     ggml_tensor * inp_pos = build_inp_pos();
 
-    auto * inp_attn = build_attn_inp_kv();
+    auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr;
+    auto * inp_attn_k  =  is_mla ? build_attn_inp_k()  : nullptr;
 
     ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -57,6 +55,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         // self_attention
         {
             ggml_tensor * q = NULL;
+
+            const bool is_lite = model.layers[il].wq;
+
             if (!is_lite) {
                 q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
                 cb(q, "q", il);
@@ -145,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_k,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
             } else {
@@ -182,7 +183,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 }
 
                 // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups)
-                cur = build_attn(inp_attn,
+                cur = build_attn(inp_attn_kv,
                             model.layers[il].wo, NULL,
                             Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
             }