Explorar el Código

kv-cache : drop the "unified" prefix (#15467)

* kv-cache : drop the "unified" prefix

ggml-ci

* cont : fix comment [no ci]
Georgi Gerganov hace 4 meses
padre
commit
715a6db02c

+ 0 - 4
include/llama.h

@@ -64,8 +64,6 @@ extern "C" {
 
     typedef struct llama_memory_i * llama_memory_t;
 
-    struct llama_kv_cache; // DEPRECATED (use llama_memory instead)
-
     typedef int32_t llama_pos;
     typedef int32_t llama_token;
     typedef int32_t llama_seq_id;
@@ -469,8 +467,6 @@ extern "C" {
     LLAMA_API           llama_memory_t   llama_get_memory  (const struct llama_context * ctx);
     LLAMA_API  enum llama_pooling_type   llama_pooling_type(const struct llama_context * ctx); // TODO: rename to llama_get_pooling_type
 
-    DEPRECATED(LLAMA_API struct llama_kv_cache * llama_get_kv_self(struct llama_context * ctx), "use llama_get_memory instead");
-
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
     LLAMA_API enum llama_rope_type       llama_model_rope_type(const struct llama_model * model);
 

+ 2 - 2
src/CMakeLists.txt

@@ -20,8 +20,8 @@ add_library(llama
             llama-hparams.cpp
             llama-impl.cpp
             llama-io.cpp
-            llama-kv-cache-unified.cpp
-            llama-kv-cache-unified-iswa.cpp
+            llama-kv-cache.cpp
+            llama-kv-cache-iswa.cpp
             llama-memory.cpp
             llama-memory-hybrid.cpp
             llama-memory-recurrent.cpp

+ 0 - 5
src/llama-context.cpp

@@ -2338,11 +2338,6 @@ const llama_model * llama_get_model(const llama_context * ctx) {
     return &ctx->get_model();
 }
 
-// deprecated
-llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
-    return dynamic_cast<llama_kv_cache *>(ctx->get_memory());
-}
-
 // deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update(false);

+ 27 - 27
src/llama-graph.cpp

@@ -4,8 +4,8 @@
 #include "llama-batch.h"
 #include "llama-cparams.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -277,7 +277,7 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                 for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
                     const llama_seq_id s0 = ubatch->seq_id[i0][0];
 
-                    // TODO: reimplement this like in llama_kv_cache_unified
+                    // TODO: reimplement this like in llama_kv_cache
                     if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
                         if (hparams.use_alibi) {
                             f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
@@ -294,15 +294,15 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) {
     mctx->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->set_input_v_idxs(self_v_idxs, ubatch);
 
     mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_context *>(params.mctx);
+bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -319,7 +319,7 @@ bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params)
     return res;
 }
 
-void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) {
     mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
     mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
@@ -331,8 +331,8 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
     mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
 }
 
-bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) {
-    const auto * mctx = static_cast<const llama_kv_cache_unified_iswa_context *>(params.mctx);
+bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_kv_cache_iswa_context *>(params.mctx);
 
     this->mctx = mctx;
 
@@ -1186,7 +1186,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
 
@@ -1399,17 +1399,17 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
+static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
            ggml_context * ctx0,
      const llama_ubatch & ubatch,
     const llama_hparams & hparams,
     const llama_cparams & cparams,
-    const llama_kv_cache_unified_context * mctx_cur) {
+    const llama_kv_cache_context * mctx_cur) {
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv>(hparams, cparams, mctx_cur);
 
     {
-        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
 
         const auto n_kv     = mctx_cur->get_n_kv();
         const auto n_tokens = ubatch.n_tokens;
@@ -1427,16 +1427,16 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
     return inp;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
+llm_graph_input_attn_kv * llm_graph_context::build_attn_inp_kv() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx);
 
-    auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
+    auto inp = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
 
-    return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified * inp,
+        llm_graph_input_attn_kv * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1488,7 +1488,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1513,7 +1513,7 @@ ggml_tensor * llm_graph_context::build_attn(
 }
 
 ggml_tensor * llm_graph_context::build_attn_with_sinks(
-        llm_graph_input_attn_kv_unified_iswa * inp,
+        llm_graph_input_attn_kv_iswa * inp,
         ggml_tensor * wo,
         ggml_tensor * wo_b,
         ggml_tensor * q_cur,
@@ -1636,10 +1636,10 @@ ggml_tensor * llm_graph_context::build_attn(
 // TODO: maybe separate the inner implementation into a separate function
 //       like with the non-sliding window equivalent
 //       once sliding-window hybrid caches are a thing.
-llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
-    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
+llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_iswa_context *>(mctx);
 
-    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
+    auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
 
     const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
 
@@ -1656,7 +1656,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
     }
 
     {
-        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
 
         const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
@@ -1669,7 +1669,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
-    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+    return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ -1792,7 +1792,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
     auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
-    auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+    auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
 

+ 25 - 25
src/llama-graph.h

@@ -19,8 +19,8 @@ struct llama_cparams;
 
 struct llama_memory_context_i;
 
-class llama_kv_cache_unified_context;
-class llama_kv_cache_unified_iswa_context;
+class llama_kv_cache_context;
+class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 
@@ -152,7 +152,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
-            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
+            const llama_kv_cache_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -161,7 +161,7 @@ public:
 
     const llama_hparams hparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -257,17 +257,17 @@ public:
     const llama_cparams cparams;
 };
 
-class llm_graph_input_attn_kv_unified : public llm_graph_input_i {
+class llm_graph_input_attn_kv : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified(
+    llm_graph_input_attn_kv(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_context * mctx) :
+            const llama_kv_cache_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified() = default;
+    ~llm_graph_input_attn_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -290,20 +290,20 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_context * mctx;
+    const llama_kv_cache_context * mctx;
 };
 
-class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+class llm_graph_input_attn_kv_iswa : public llm_graph_input_i {
 public:
-    llm_graph_input_attn_kv_unified_iswa(
+    llm_graph_input_attn_kv_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
-            const llama_kv_cache_unified_iswa_context * mctx) :
+            const llama_kv_cache_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
         mctx(mctx) {
     }
-    ~llm_graph_input_attn_kv_unified_iswa() = default;
+    ~llm_graph_input_attn_kv_iswa() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
@@ -330,7 +330,7 @@ public:
     const llama_hparams hparams;
     const llama_cparams cparams;
 
-    const llama_kv_cache_unified_iswa_context * mctx;
+    const llama_kv_cache_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -351,7 +351,7 @@ public:
 class llm_graph_input_mem_hybrid : public llm_graph_input_i {
 public:
     llm_graph_input_mem_hybrid(
-            std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+            std::unique_ptr<llm_graph_input_attn_kv> inp_attn,
             std::unique_ptr<llm_graph_input_rs>              inp_rs,
             const llama_memory_hybrid_context *              mctx) :
         inp_attn(std::move(inp_attn)),
@@ -361,11 +361,11 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
-    std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
-    std::unique_ptr<llm_graph_input_rs>              inp_rs;
+    std::unique_ptr<llm_graph_input_attn_kv> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>      inp_rs;
 
-    llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
-    llm_graph_input_rs              * get_recr() const { return inp_rs.get(); }
+    llm_graph_input_attn_kv * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs      * get_recr() const { return inp_rs.get(); }
 
     const llama_memory_hybrid_context * mctx;
 };
@@ -703,10 +703,10 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
+    llm_graph_input_attn_kv * build_attn_inp_kv() const;
 
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -717,11 +717,11 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
-    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+    llm_graph_input_attn_kv_iswa * build_attn_inp_kv_iswa() const;
 
     // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -734,7 +734,7 @@ struct llm_graph_context {
 
     // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
     ggml_tensor * build_attn_with_sinks(
-            llm_graph_input_attn_kv_unified_iswa * inp,
+            llm_graph_input_attn_kv_iswa * inp,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
@@ -765,7 +765,7 @@ struct llm_graph_context {
     //
 
     // TODO: move this implementation to llama_memory_recurrent.
-    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       this is analogous to llama_kv_cache::cpy_k / cpy_v
     //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
     //         implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
     //         `llama_memory_recurrent`

+ 47 - 47
src/llama-kv-cache-unified-iswa.cpp → src/llama-kv-cache-iswa.cpp

@@ -1,4 +1,4 @@
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache-iswa.h"
 
 #include "llama-impl.h"
 #include "llama-batch.h"
@@ -8,10 +8,10 @@
 #include <cassert>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
+llama_kv_cache_iswa::llama_kv_cache_iswa(
         const llama_model & model,
                 ggml_type   type_k,
                 ggml_type   type_v,
@@ -23,8 +23,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                  uint32_t   n_seq_max,
                  uint32_t   n_ubatch,
                  uint32_t   n_pad) : hparams(model.hparams), unified(unified) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
@@ -40,25 +40,25 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base);
 
-    kv_base = std::make_unique<llama_kv_cache_unified>(
+    kv_base = std::make_unique<llama_kv_cache>(
             model, std::move(filter_base), type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating     SWA KV cache, size = %u cells\n", __func__, size_swa);
 
-    kv_swa = std::make_unique<llama_kv_cache_unified>(
+    kv_swa = std::make_unique<llama_kv_cache>(
             model, std::move(filter_swa), type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
 
-void llama_kv_cache_unified_iswa::clear(bool data) {
+void llama_kv_cache_iswa::clear(bool data) {
     kv_base->clear(data);
     kv_swa ->clear(data);
 }
 
-bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     bool res = true;
 
     res = res & kv_base->seq_rm(seq_id, p0, p1);
@@ -67,36 +67,36 @@ bool llama_kv_cache_unified_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llam
     return res;
 }
 
-void llama_kv_cache_unified_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     kv_base->seq_cp(seq_id_src, seq_id_dst, p0, p1);
     kv_swa ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
-void llama_kv_cache_unified_iswa::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache_iswa::seq_keep(llama_seq_id seq_id) {
     kv_base->seq_keep(seq_id);
     kv_swa ->seq_keep(seq_id);
 }
 
-void llama_kv_cache_unified_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_kv_cache_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     kv_base->seq_add(seq_id, p0, p1, shift);
     kv_swa ->seq_add(seq_id, p0, p1, shift);
 }
 
-void llama_kv_cache_unified_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     kv_base->seq_div(seq_id, p0, p1, d);
     kv_swa ->seq_div(seq_id, p0, p1, d);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_min(llama_seq_id seq_id) const {
     // the base cache is a superset of the SWA cache, so we can just check the SWA cache
     return kv_swa->seq_pos_min(seq_id);
 }
 
-llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -136,7 +136,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
@@ -172,29 +172,29 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
 
         assert(sinfos_base.size() == sinfos_swa.size());
 
-        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+        return std::make_unique<llama_kv_cache_iswa_context>(
                 this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     //       but to do that properly, we first have to refactor the batches to be more flexible
 
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
+llama_memory_context_ptr llama_kv_cache_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_iswa_context>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
-    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
+llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_iswa_context>(this, lctx, optimize);
 }
 
-bool llama_kv_cache_unified_iswa::get_can_shift() const {
+bool llama_kv_cache_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
 
-void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }
@@ -202,7 +202,7 @@ void llama_kv_cache_unified_iswa::state_write(llama_io_write_i & io, llama_seq_i
     kv_swa->state_write(io, seq_id, flags);
 }
 
-void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }
@@ -210,29 +210,29 @@ void llama_kv_cache_unified_iswa::state_read(llama_io_read_i & io, llama_seq_id
     kv_swa->state_read(io, seq_id, flags);
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_base() const {
+llama_kv_cache * llama_kv_cache_iswa::get_base() const {
     return kv_base.get();
 }
 
-llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
+llama_kv_cache * llama_kv_cache_iswa::get_swa() const {
     return kv_swa.get();
 }
 
 //
-// llama_kv_cache_unified_iswa_context
+// llama_kv_cache_iswa_context
 //
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv) :
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv) :
     ctx_base(kv->get_base()->init_full()),
     ctx_swa (kv->get_swa ()->init_full()),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         llama_context * lctx,
         bool optimize) :
     ctx_base(kv->get_base()->init_update(lctx, optimize)),
@@ -240,21 +240,21 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
-        llama_kv_cache_unified_iswa * kv,
+llama_kv_cache_iswa_context::llama_kv_cache_iswa_context(
+        llama_kv_cache_iswa * kv,
         slot_info_vec_t sinfos_base,
         slot_info_vec_t sinfos_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
-    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
+    ctx_base(new llama_kv_cache_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_context(kv->get_swa (), std::move(sinfos_swa),  this->ubatches)),
     status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
-llama_kv_cache_unified_iswa_context:: ~llama_kv_cache_unified_iswa_context() = default;
+llama_kv_cache_iswa_context:: ~llama_kv_cache_iswa_context() = default;
 
-bool llama_kv_cache_unified_iswa_context::next() {
+bool llama_kv_cache_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     ctx_base->next();
@@ -267,7 +267,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_iswa_context::apply() {
+bool llama_kv_cache_iswa_context::apply() {
     assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
@@ -278,24 +278,24 @@ bool llama_kv_cache_unified_iswa_context::apply() {
     return res;
 }
 
-llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
+llama_memory_status llama_kv_cache_iswa_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_iswa_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_base() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_base.get());
 }
 
-const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa()  const {
+const llama_kv_cache_context * llama_kv_cache_iswa_context::get_swa()  const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
+    return static_cast<const llama_kv_cache_context *>(ctx_swa.get());
 }

+ 26 - 26
src/llama-kv-cache-unified-iswa.h → src/llama-kv-cache-iswa.h

@@ -1,32 +1,32 @@
 #pragma once
 
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include <vector>
 
 //
-// llama_kv_cache_unified_iswa
+// llama_kv_cache_iswa
 //
 
-// utilizes two instances of llama_kv_cache_unified
+// utilizes two instances of llama_kv_cache
 //   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
 
-class llama_kv_cache_unified_iswa : public llama_memory_i {
+class llama_kv_cache_iswa : public llama_memory_i {
 public:
-    llama_kv_cache_unified_iswa(
+    llama_kv_cache_iswa(
             const llama_model & model,
                     ggml_type   type_k,
                     ggml_type   type_v,
                          bool   v_trans,
                          bool   offload,
                          bool   swa_full,
-                         bool   unified,
+                         bool  ,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
                      uint32_t   n_ubatch,
                      uint32_t   n_pad);
 
-    ~llama_kv_cache_unified_iswa() = default;
+    ~llama_kv_cache_iswa() = default;
 
     //
     // llama_memory_i
@@ -60,46 +60,46 @@ public:
     void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified_iswa specific API
+    // llama_kv_cache_iswa specific API
     //
 
-    llama_kv_cache_unified * get_base() const;
-    llama_kv_cache_unified * get_swa () const;
+    llama_kv_cache * get_base() const;
+    llama_kv_cache * get_swa () const;
 
 private:
     const llama_hparams & hparams;
 
     const bool unified;
 
-    std::unique_ptr<llama_kv_cache_unified> kv_base;
-    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+    std::unique_ptr<llama_kv_cache> kv_base;
+    std::unique_ptr<llama_kv_cache> kv_swa;
 };
 
-class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
+class llama_kv_cache_iswa_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // used for errors
-    llama_kv_cache_unified_iswa_context(llama_memory_status status);
+    llama_kv_cache_iswa_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv);
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
     // used to create a batch processing context from a batch
-    llama_kv_cache_unified_iswa_context(
-            llama_kv_cache_unified_iswa * kv,
+    llama_kv_cache_iswa_context(
+            llama_kv_cache_iswa * kv,
             slot_info_vec_t sinfos_base,
             slot_info_vec_t sinfos_swa,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_iswa_context();
+    virtual ~llama_kv_cache_iswa_context();
 
     //
     // llama_memory_context_i
@@ -112,14 +112,14 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_iswa_context specific API
+    // llama_kv_cache_iswa_context specific API
     //
 
-    const llama_kv_cache_unified_context * get_base() const;
-    const llama_kv_cache_unified_context * get_swa()  const;
+    const llama_kv_cache_context * get_base() const;
+    const llama_kv_cache_context * get_swa()  const;
 
 private:
-    //llama_kv_cache_unified_iswa * kv;
+    //llama_kv_cache_iswa * kv;
 
     // the index of the next ubatch to process
     size_t i_next = 0;

+ 85 - 85
src/llama-kv-cache-unified.cpp → src/llama-kv-cache.cpp

@@ -1,4 +1,4 @@
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 
 #include "llama-impl.h"
 #include "llama-io.h"
@@ -13,10 +13,10 @@
 #include <stdexcept>
 
 //
-// llama_kv_cache_unified
+// llama_kv_cache
 //
 
-llama_kv_cache_unified::llama_kv_cache_unified(
+llama_kv_cache::llama_kv_cache(
         const llama_model &  model,
           layer_filter_cb && filter,
                 ggml_type    type_k,
@@ -209,7 +209,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     }
 }
 
-void llama_kv_cache_unified::clear(bool data) {
+void llama_kv_cache::clear(bool data) {
     for (uint32_t s = 0; s < n_stream; ++s) {
         v_cells[s].reset();
         v_heads[s] = 0;
@@ -222,7 +222,7 @@ void llama_kv_cache_unified::clear(bool data) {
     }
 }
 
-bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
     if (p0 < 0) {
@@ -285,7 +285,7 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     return true;
 }
 
-void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
     GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size());
     GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size());
 
@@ -368,7 +368,7 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id
     //}
 }
 
-void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
+void llama_kv_cache::seq_keep(llama_seq_id seq_id) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -390,7 +390,7 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) {
     }
 }
 
-void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -434,7 +434,7 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po
     head = new_head != cells.size() ? new_head : 0;
 }
 
-void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -467,7 +467,7 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po
     }
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -475,7 +475,7 @@ llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const {
     return cells.seq_pos_min(seq_id);
 }
 
-llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
+llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
     GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
 
     const auto & cells = v_cells[seq_to_stream[seq_id]];
@@ -483,7 +483,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_batch(
+llama_memory_context_ptr llama_kv_cache::init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) {
@@ -513,18 +513,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
             break;
         }
 
-        return std::make_unique<llama_kv_cache_unified_context>(
+        return std::make_unique<llama_kv_cache_context>(
                 this, std::move(sinfos), std::move(ubatches));
     } while (false);
 
-    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    return std::make_unique<llama_kv_cache_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_full() {
-    return std::make_unique<llama_kv_cache_unified_context>(this);
+llama_memory_context_ptr llama_kv_cache::init_full() {
+    return std::make_unique<llama_kv_cache_context>(this);
 }
 
-llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
+llama_memory_context_ptr llama_kv_cache::init_update(llama_context * lctx, bool optimize) {
     bool do_shift = get_has_shift();
 
     defrag_info dinfo;
@@ -557,18 +557,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
         }
     }
 
-    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
+    return std::make_unique<llama_kv_cache_context>(this, lctx, do_shift, std::move(dinfo), std::move(sc_info));
 }
 
-llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
-    llama_kv_cache_unified::slot_info_vec_t res;
+llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ubatch> & ubatches) {
+    llama_kv_cache::slot_info_vec_t res;
 
     struct state_t {
         slot_info sinfo; // slot info for the ubatch
 
         std::vector<uint32_t> v_heads_old; // old positions of the heads, before placing the ubatch
 
-        std::vector<llama_kv_cells_unified> v_cells; // copy of the old cells, before placing the ubatch
+        std::vector<llama_kv_cells> v_cells; // copy of the old cells, before placing the ubatch
     };
 
     // remember the old state of the cells so we can restore it in the end
@@ -629,7 +629,7 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st
     return res;
 }
 
-bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
+bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) {
     bool updated = false;
 
     auto * sched = lctx->get_sched();
@@ -749,7 +749,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
     return updated;
 }
 
-llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
+llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, bool cont) const {
 
     if (debug > 0) {
         for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
@@ -948,7 +948,7 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
     return res;
 }
 
-void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
+void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -1013,21 +1013,21 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u
     }
 }
 
-bool llama_kv_cache_unified::get_can_shift() const {
+bool llama_kv_cache::get_can_shift() const {
     return true;
 }
 
-uint32_t llama_kv_cache_unified::get_size() const {
+uint32_t llama_kv_cache::get_size() const {
     const auto & cells = v_cells[seq_to_stream[0]];
 
     return cells.size();
 }
 
-uint32_t llama_kv_cache_unified::get_n_stream() const {
+uint32_t llama_kv_cache::get_n_stream() const {
     return n_stream;
 }
 
-bool llama_kv_cache_unified::get_has_shift() const {
+bool llama_kv_cache::get_has_shift() const {
     bool result = false;
 
     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1037,7 +1037,7 @@ bool llama_kv_cache_unified::get_has_shift() const {
     return result;
 }
 
-uint32_t llama_kv_cache_unified::get_n_kv() const {
+uint32_t llama_kv_cache::get_n_kv() const {
     uint32_t result = 0;
 
     for (uint32_t s = 0; s < n_stream; ++s) {
@@ -1049,11 +1049,11 @@ uint32_t llama_kv_cache_unified::get_n_kv() const {
     return result;
 }
 
-bool llama_kv_cache_unified::get_supports_set_rows() const {
+bool llama_kv_cache::get_supports_set_rows() const {
     return supports_set_rows;
 }
 
-ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -1073,7 +1073,7 @@ ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint
             ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1105,7 +1105,7 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
             ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * k = layers[ikv].k;
@@ -1135,7 +1135,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
     return ggml_cpy(ctx, k_cur, k_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
+ggml_tensor * llama_kv_cache::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
     const int32_t ikv = map_layer_ids.at(il);
 
     auto * v = layers[ikv].v;
@@ -1189,7 +1189,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
     return ggml_cpy(ctx, v_cur, v_view);
 }
 
-ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
@@ -1199,7 +1199,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con
     return k_idxs;
 }
 
-ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     const uint32_t n_tokens = ubatch.n_tokens;
 
     ggml_tensor * v_idxs;
@@ -1215,7 +1215,7 @@ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, con
     return v_idxs;
 }
 
-void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1235,7 +1235,7 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba
     }
 }
 
-void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+void llama_kv_cache::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
     if (!supports_set_rows) {
         return;
     }
@@ -1272,7 +1272,7 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba
     }
 }
 
-void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const {
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 
     int32_t * data = (int32_t *) dst->data;
@@ -1286,7 +1286,7 @@ void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const {
     }
 }
 
-void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
@@ -1358,7 +1358,7 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     }
 }
 
-void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     const int64_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams");
@@ -1383,7 +1383,7 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-size_t llama_kv_cache_unified::total_size() const {
+size_t llama_kv_cache::total_size() const {
     size_t size = 0;
 
     for (const auto & buf : bufs) {
@@ -1393,7 +1393,7 @@ size_t llama_kv_cache_unified::total_size() const {
     return size;
 }
 
-size_t llama_kv_cache_unified::size_k_bytes() const {
+size_t llama_kv_cache::size_k_bytes() const {
     size_t size_k_bytes = 0;
 
     for (const auto & layer : layers) {
@@ -1403,7 +1403,7 @@ size_t llama_kv_cache_unified::size_k_bytes() const {
     return size_k_bytes;
 }
 
-size_t llama_kv_cache_unified::size_v_bytes() const {
+size_t llama_kv_cache::size_v_bytes() const {
     size_t size_v_bytes = 0;
 
     for (const auto & layer : layers) {
@@ -1413,7 +1413,7 @@ size_t llama_kv_cache_unified::size_v_bytes() const {
     return size_v_bytes;
 }
 
-ggml_tensor * llama_kv_cache_unified::build_rope_shift(
+ggml_tensor * llama_kv_cache::build_rope_shift(
         const llama_cparams & cparams,
                ggml_context * ctx,
                 ggml_tensor * cur,
@@ -1465,14 +1465,14 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift(
 
 class llm_graph_input_k_shift : public llm_graph_input_i {
 public:
-    llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
+    llm_graph_input_k_shift(const llama_kv_cache * kv_self) : kv_self(kv_self) {}
     virtual ~llm_graph_input_k_shift() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * k_shift; // I32 [kv_size*n_stream]
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache * kv_self;
 };
 
 void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
@@ -1483,7 +1483,7 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
     }
 }
 
-ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
+ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const {
     auto * ctx = res->get_ctx();
     auto * gf  = res->get_gf();
 
@@ -1525,7 +1525,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res,
     return gf;
 }
 
-ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
+ggml_cgraph * llama_kv_cache::build_graph_defrag(
          llm_graph_result * res,
             llama_context * lctx,
         const defrag_info & dinfo) const {
@@ -1679,7 +1679,7 @@ ggml_cgraph * llama_kv_cache_unified::build_graph_defrag(
     return gf;
 }
 
-llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const {
+llama_kv_cache::defrag_info llama_kv_cache::defrag_prepare(int32_t n_max_nodes) const {
     GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag");
 
     const auto & cells = v_cells[0];
@@ -1802,7 +1802,7 @@ llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32
     return res;
 }
 
-bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {
@@ -1828,7 +1828,7 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const {
     return false;
 }
 
-void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
 
     io.write(&n_stream, sizeof(n_stream));
@@ -1881,7 +1881,7 @@ void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq
     }
 }
 
-void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
     GGML_UNUSED(flags);
 
     GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
@@ -1917,7 +1917,7 @@ void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_i
     }
 }
 
-void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
+void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const {
     const auto & cells = v_cells[cr.strm];
 
     for (const auto & range : cr.data) {
@@ -1945,7 +1945,7 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_
     }
 }
 
-void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
+void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const {
     const auto & cells = v_cells[cr.strm];
 
     const uint32_t v_trans = this->v_trans ? 1 : 0;
@@ -2040,7 +2040,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_
     }
 }
 
-bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];
 
@@ -2137,7 +2137,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm
     return true;
 }
 
-bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];
 
@@ -2274,13 +2274,13 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm
 }
 
 //
-// llama_kv_cache_unified_context
+// llama_kv_cache_context
 //
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
+llama_kv_cache_context::llama_kv_cache_context(llama_memory_status status) : status(status) {}
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
 
     const uint32_t n_stream = kv->get_n_stream();
@@ -2296,8 +2296,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     }
 }
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
         llama_context * lctx,
         bool do_shift,
         defrag_info dinfo,
@@ -2307,15 +2307,15 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
     }
 }
 
-llama_kv_cache_unified_context::llama_kv_cache_unified_context(
-        llama_kv_cache_unified * kv,
-        llama_kv_cache_unified::slot_info_vec_t sinfos,
+llama_kv_cache_context::llama_kv_cache_context(
+        llama_kv_cache * kv,
+        llama_kv_cache::slot_info_vec_t sinfos,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
 }
 
-llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
+llama_kv_cache_context::~llama_kv_cache_context() = default;
 
-bool llama_kv_cache_unified_context::next() {
+bool llama_kv_cache_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_cur >= ubatches.size()) {
@@ -2325,7 +2325,7 @@ bool llama_kv_cache_unified_context::next() {
     return true;
 }
 
-bool llama_kv_cache_unified_context::apply() {
+bool llama_kv_cache_context::apply() {
     assert(!llama_memory_status_is_fail(status));
 
     // no ubatches -> this is a KV cache update
@@ -2342,69 +2342,69 @@ bool llama_kv_cache_unified_context::apply() {
     return true;
 }
 
-llama_memory_status llama_kv_cache_unified_context::get_status() const {
+llama_memory_status llama_kv_cache_context::get_status() const {
     return status;
 }
 
-const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
+const llama_ubatch & llama_kv_cache_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_cur];
 }
 
-uint32_t llama_kv_cache_unified_context::get_n_kv() const {
+uint32_t llama_kv_cache_context::get_n_kv() const {
     return n_kv;
 }
 
-bool llama_kv_cache_unified_context::get_supports_set_rows() const {
+bool llama_kv_cache_context::get_supports_set_rows() const {
     return kv->get_supports_set_rows();
 }
 
-ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
+ggml_tensor * llama_kv_cache_context::get_v(ggml_context * ctx, int32_t il) const {
     return kv->get_v(ctx, il, n_kv, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+ggml_tensor * llama_kv_cache_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
     return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+ggml_tensor * llama_kv_cache_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
     return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     return kv->build_input_k_idxs(ctx, ubatch);
 }
 
-ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
     return kv->build_input_v_idxs(ctx, ubatch);
 }
 
-void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
+void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
-void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
-void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
 }
 
-void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
+void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
 
-void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
 
-uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) {
+uint32_t llama_kv_cache::get_padding(const llama_cparams & cparams) {
     // the FA kernels require padding to avoid extra runtime boundary checks
     return cparams.flash_attn ? 256u : 32u;
 }

+ 20 - 20
src/llama-kv-cache-unified.h → src/llama-kv-cache.h

@@ -14,10 +14,10 @@ struct llama_model;
 struct llama_context;
 
 //
-// llama_kv_cache_unified
+// llama_kv_cache
 //
 
-class llama_kv_cache_unified : public llama_memory_i {
+class llama_kv_cache : public llama_memory_i {
 public:
     static uint32_t get_padding(const llama_cparams & cparams);
 
@@ -92,7 +92,7 @@ public:
 
     using slot_info_vec_t = std::vector<slot_info>;
 
-    llama_kv_cache_unified(
+    llama_kv_cache(
             const llama_model &  model,
               layer_filter_cb && filter,
                     ggml_type    type_k,
@@ -106,7 +106,7 @@ public:
                      uint32_t    n_swa,
                llama_swa_type    swa_type);
 
-    ~llama_kv_cache_unified() = default;
+    ~llama_kv_cache() = default;
 
     //
     // llama_memory_i
@@ -140,7 +140,7 @@ public:
     void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override;
 
     //
-    // llama_kv_cache_unified specific API
+    // llama_kv_cache specific API
     //
 
     uint32_t get_size()     const;
@@ -241,7 +241,7 @@ private:
     // note: this is not part of the KV state and it's only used to speed-up the find_slot() method
     std::vector<uint32_t> v_heads;
 
-    std::vector<llama_kv_cells_unified> v_cells;
+    std::vector<llama_kv_cells> v_cells;
 
     // maps from a sequence id to a stream id
     std::vector<uint32_t> seq_to_stream;
@@ -295,35 +295,35 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
 };
 
-class llama_kv_cache_unified_context : public llama_memory_context_i {
+class llama_kv_cache_context : public llama_memory_context_i {
 public:
     // some shorthands
-    using slot_info_vec_t  = llama_kv_cache_unified::slot_info_vec_t;
-    using defrag_info      = llama_kv_cache_unified::defrag_info;
-    using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
+    using slot_info_vec_t  = llama_kv_cache::slot_info_vec_t;
+    using defrag_info      = llama_kv_cache::defrag_info;
+    using stream_copy_info = llama_kv_cache::stream_copy_info;
 
     // used for errors
-    llama_kv_cache_unified_context(llama_memory_status status);
+    llama_kv_cache_context(llama_memory_status status);
 
     // used to create a full-cache context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv);
+    llama_kv_cache_context(
+            llama_kv_cache * kv);
 
     // used to create an update context
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             llama_context * lctx,
             bool do_shift,
             defrag_info dinfo,
             stream_copy_info sc_info);
 
     // used to create a batch procesing context from a batch
-    llama_kv_cache_unified_context(
-            llama_kv_cache_unified * kv,
+    llama_kv_cache_context(
+            llama_kv_cache * kv,
             slot_info_vec_t sinfos,
             std::vector<llama_ubatch> ubatches);
 
-    virtual ~llama_kv_cache_unified_context();
+    virtual ~llama_kv_cache_context();
 
     //
     // llama_memory_context_i
@@ -336,7 +336,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
-    // llama_kv_cache_unified_context specific API
+    // llama_kv_cache_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -365,7 +365,7 @@ public:
 private:
     llama_memory_status status;
 
-    llama_kv_cache_unified * kv;
+    llama_kv_cache * kv;
     llama_context * lctx;
 
     //

+ 7 - 7
src/llama-kv-cells.h

@@ -11,7 +11,7 @@
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
-class llama_kv_cells_unified {
+class llama_kv_cells {
 public:
     void reset() {
         for (uint32_t i = 0; i < pos.size(); ++i) {
@@ -97,10 +97,10 @@ public:
     }
 
     // copy the state of cells [i, i + n) (used for save/restore the state of the cells)
-    llama_kv_cells_unified cp(uint32_t i, uint32_t n) const {
+    llama_kv_cells cp(uint32_t i, uint32_t n) const {
         assert(i + n <= pos.size());
 
-        llama_kv_cells_unified res;
+        llama_kv_cells res;
 
         res.resize(n);
 
@@ -117,8 +117,8 @@ public:
     }
 
     // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
-        llama_kv_cells_unified res;
+    llama_kv_cells cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells res;
 
         res.resize(idxs.size());
 
@@ -135,7 +135,7 @@ public:
     }
 
     // set the state of cells [i, i + other.pos.size()) (used for save/restore the state of the cells)
-    void set(uint32_t i, const llama_kv_cells_unified & other) {
+    void set(uint32_t i, const llama_kv_cells & other) {
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
@@ -165,7 +165,7 @@ public:
     }
 
     // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
-    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells & other) {
         assert(idxs.size() == other.pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {

+ 5 - 5
src/llama-memory-hybrid.cpp

@@ -30,7 +30,7 @@ llama_memory_hybrid::llama_memory_hybrid(
       layer_filter_cb && filter_attn,
       layer_filter_cb && filter_recr) :
     hparams(model.hparams),
-    mem_attn(new llama_kv_cache_unified(
+    mem_attn(new llama_kv_cache(
         model,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
@@ -179,7 +179,7 @@ void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id,
     mem_recr->state_read(io, seq_id);
 }
 
-llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
+llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
     return mem_attn.get();
 }
 
@@ -210,7 +210,7 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
         std::vector<llama_ubatch>   ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
+    ctx_attn(new llama_kv_cache_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(),                        this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
@@ -248,8 +248,8 @@ const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     return ubatches[i_next];
 }
 
-const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
-    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
+const llama_kv_cache_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_context *>(ctx_attn.get());
 }
 
 const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {

+ 6 - 6
src/llama-memory-hybrid.h

@@ -2,7 +2,7 @@
 
 #include "llama-batch.h"
 #include "llama-graph.h"
-#include "llama-kv-cache-unified.h"
+#include "llama-kv-cache.h"
 #include "llama-memory.h"
 #include "llama-memory-recurrent.h"
 
@@ -13,7 +13,7 @@
 // llama_memory_hybrid
 //
 
-// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
+// utilizes instances of llama_memory_recurrent and llama_kv_cache to
 //   support models where each layer may be either attention-based or recurrent
 
 class llama_memory_hybrid : public llama_memory_i {
@@ -81,19 +81,19 @@ public:
     // llama_memory_hybrid specific API
     //
 
-    llama_kv_cache_unified * get_mem_attn() const;
+    llama_kv_cache * get_mem_attn() const;
     llama_memory_recurrent * get_mem_recr() const;
 
 private:
     const llama_hparams & hparams;
 
-    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
+    const std::unique_ptr<llama_kv_cache> mem_attn;
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
-    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
 
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
@@ -125,7 +125,7 @@ public:
     // llama_memory_hybrid_context
     //
 
-    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_kv_cache_context * get_attn() const;
     const llama_memory_recurrent_context * get_recr() const;
 
 private:

+ 1 - 1
src/llama-memory-recurrent.h

@@ -12,7 +12,7 @@
 //
 
 // TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
-//       see the implementation of llama_kv_cache_unified_context_i for an example how to do it
+//       see the implementation of llama_kv_cache_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:
 

+ 2 - 7
src/llama-memory.h

@@ -36,8 +36,8 @@ bool llama_memory_status_is_fail(llama_memory_status status);
 
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
-//   - llama_kv_cache_unified_context
-//   - llama_kv_cache_unified_iswa_context
+//   - llama_kv_cache_context
+//   - llama_kv_cache_iswa_context
 //   ...
 //
 // the only method that should mutate the memory and the memory context is llama_memory_i::apply()
@@ -109,8 +109,3 @@ struct llama_memory_i {
 };
 
 using llama_memory_ptr = std::unique_ptr<llama_memory_i>;
-
-// TODO: temporary until the llama_kv_cache is removed from the public API
-struct llama_kv_cache : public llama_memory_i {
-    virtual ~llama_kv_cache() = default;
-};

+ 93 - 93
src/llama-model.cpp

@@ -6,8 +6,8 @@
 #include "llama-cparams.h"
 #include "llama-model-loader.h"
 
-#include "llama-kv-cache-unified.h"
-#include "llama-kv-cache-unified-iswa.h"
+#include "llama-kv-cache.h"
+#include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-recurrent.h"
 
@@ -5986,7 +5986,7 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6146,7 +6146,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
         ggml_tensor * inp_attn_scale = nullptr;
         inp_attn_scale = build_inp_attn_scale();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6325,7 +6325,7 @@ struct llm_build_deci : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -6481,7 +6481,7 @@ struct llm_build_baichuan : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = model.type == LLM_TYPE_7B ? build_inp_pos() : nullptr;
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6603,7 +6603,7 @@ struct llm_build_xverse : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6717,7 +6717,7 @@ struct llm_build_falcon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -6841,7 +6841,7 @@ struct llm_build_grok : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7001,7 +7001,7 @@ struct llm_build_dbrx : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7125,7 +7125,7 @@ struct llm_build_starcoder : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -7230,7 +7230,7 @@ struct llm_build_refact : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -7632,7 +7632,7 @@ struct llm_build_bloom : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         inpL = build_norm(inpL,
                 model.tok_norm,
@@ -7739,7 +7739,7 @@ struct llm_build_mpt : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         if (model.pos_embd) {
             // inp_pos - contains the positions
@@ -7889,7 +7889,7 @@ struct llm_build_stablelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8041,7 +8041,7 @@ struct llm_build_qwen : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8156,7 +8156,7 @@ struct llm_build_qwen2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8481,7 +8481,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         int sections[4];
         std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);
@@ -8602,7 +8602,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8761,7 +8761,7 @@ struct llm_build_qwen3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -8882,7 +8882,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9012,7 +9012,7 @@ struct llm_build_phi2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9141,13 +9141,13 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -9299,7 +9299,7 @@ struct llm_build_plamo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9415,7 +9415,7 @@ struct llm_build_gpt2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
         cb(pos, "pos_embd", -1);
@@ -9525,7 +9525,7 @@ struct llm_build_codeshell : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9638,7 +9638,7 @@ struct llm_build_orion : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9765,7 +9765,7 @@ struct llm_build_internlm2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -9901,7 +9901,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10096,7 +10096,7 @@ struct llm_build_gemma : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10212,7 +10212,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10346,7 +10346,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -10497,7 +10497,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
         ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
@@ -10904,7 +10904,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11473,7 +11473,7 @@ struct llm_build_command_r : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11620,7 +11620,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11755,7 +11755,7 @@ struct llm_build_olmo : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -11883,7 +11883,7 @@ struct llm_build_olmo2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12012,7 +12012,7 @@ struct llm_build_olmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12138,7 +12138,7 @@ struct llm_build_openelm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12269,7 +12269,7 @@ struct llm_build_gptneox : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12415,7 +12415,7 @@ struct llm_build_arctic : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12553,7 +12553,7 @@ struct llm_build_deepseek : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -12730,7 +12730,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -12977,7 +12977,7 @@ struct llm_build_bitnet : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13241,7 +13241,7 @@ struct llm_build_t5_dec : public llm_graph_context {
 
         const int64_t n_outputs_enc = embd_enc->ne[1];
 
-        auto * inp_attn_self  = build_attn_inp_kv_unified();
+        auto * inp_attn_self  = build_attn_inp_kv();
         auto * inp_attn_cross = build_attn_inp_cross();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -13406,7 +13406,7 @@ struct llm_build_jais : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13504,7 +13504,7 @@ struct llm_build_chatglm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13637,7 +13637,7 @@ struct llm_build_glm4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13787,7 +13787,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -13947,7 +13947,7 @@ struct llm_build_nemotron : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14076,7 +14076,7 @@ struct llm_build_exaone : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -14208,13 +14208,13 @@ struct llm_build_exaone4 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -15097,7 +15097,7 @@ struct llm_build_granite : public llm_graph_context {
             inp_pos = build_inp_pos();
         }
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15148,12 +15148,12 @@ struct llm_build_granite : public llm_graph_context {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15367,12 +15367,12 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba {
     }
 
     ggml_tensor * build_attention_layer(
-              ggml_tensor                     * cur,
-              ggml_tensor                     * inp_pos,
-              llm_graph_input_attn_kv_unified * inp_attn,
-        const llama_model                     & model,
-        const int64_t                           n_embd_head,
-        const int                               il) {
+              ggml_tensor             * cur,
+              ggml_tensor             * inp_pos,
+              llm_graph_input_attn_kv * inp_attn,
+        const llama_model             & model,
+        const int64_t                 n_embd_head,
+        const int                     il) {
 
         // compute Q and K and (optionally) RoPE them
         ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -15529,7 +15529,7 @@ struct llm_build_chameleon : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -15860,7 +15860,7 @@ struct llm_build_plm : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16025,7 +16025,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16174,7 +16174,7 @@ struct llm_build_dots1 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16324,7 +16324,7 @@ struct llm_build_ernie4_5 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -16454,7 +16454,7 @@ struct llm_build_ernie4_5_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -16828,7 +16828,7 @@ struct llm_build_plamo2 : public llm_graph_context_mamba {
 
 private:
     ggml_tensor * build_plamo2_attn_layer(
-            llm_graph_input_attn_kv_unified * inp,
+            llm_graph_input_attn_kv * inp,
             ggml_tensor * inp_pos,
             ggml_tensor * cur,
             const llama_model & model,
@@ -17061,7 +17061,7 @@ struct llm_build_arcee : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17196,7 +17196,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17357,7 +17357,7 @@ struct llm_build_hunyuan_dense : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = 1.0f / sqrtf(float(n_embd_head));
 
@@ -17495,7 +17495,7 @@ struct llm_build_smollm3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
 
@@ -17627,7 +17627,7 @@ struct llm_build_openai_moe_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -17809,10 +17809,10 @@ struct llm_build_lfm2 : public llm_graph_context {
         return cur;
     }
 
-    ggml_tensor * build_attn_block(ggml_tensor                     * cur,
-                                   ggml_tensor                     * inp_pos,
-                                   llm_graph_input_attn_kv_unified * inp_attn,
-                                   int                               il) const {
+    ggml_tensor * build_attn_block(ggml_tensor             * cur,
+                                   ggml_tensor             * inp_pos,
+                                   llm_graph_input_attn_kv * inp_attn,
+                                   int                     il) const {
         GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il));
         auto const n_embd_head = hparams.n_embd_head_v;
         auto const n_head_kv = hparams.n_head_kv(il);
@@ -17940,13 +17940,13 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>;
         inp_attn_type * inp_attn = nullptr;
 
         if constexpr (iswa) {
-            inp_attn = build_attn_inp_kv_unified_iswa();
+            inp_attn = build_attn_inp_kv_iswa();
         } else {
-            inp_attn = build_attn_inp_kv_unified();
+            inp_attn = build_attn_inp_kv();
         }
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
@@ -18076,7 +18076,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             std::max((uint32_t) 1, cparams.n_seq_max),
                             cparams.n_seq_max);
                 } else if (llm_arch_is_hybrid(arch)) {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
 
@@ -18098,7 +18098,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         /* filter_attn       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
                         /* filter_recr       */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
                 } else {
-                    const auto padding = llama_kv_cache_unified::get_padding(cparams);
+                    const auto padding = llama_kv_cache::get_padding(cparams);
 
                     uint32_t n_ctx_per_stream = cparams.n_ctx;
 
@@ -18118,7 +18118,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
                         GGML_ASSERT(hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified_iswa(
+                        res = new llama_kv_cache_iswa(
                                 *this,
                                 params.type_k,
                                 params.type_v,
@@ -18133,7 +18133,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     } else {
                         GGML_ASSERT(!hparams.is_swa_any());
 
-                        res = new llama_kv_cache_unified(
+                        res = new llama_kv_cache(
                                 *this,
                                 nullptr,
                                 params.type_k,