Browse Source

memory : add llama_memory_hybrid_iswa (#18601)

* memory : add llama_memory_hybrid_iswa

* Update src/llama-memory-hybrid-iswa.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Tarek Dakhran 1 week ago
parent
commit
ad8d85bd94
6 changed files with 598 additions and 17 deletions
  1. 1 0
      src/CMakeLists.txt
  2. 112 0
      src/llama-graph.cpp
  3. 31 0
      src/llama-graph.h
  4. 275 0
      src/llama-memory-hybrid-iswa.cpp
  5. 140 0
      src/llama-memory-hybrid-iswa.h
  6. 39 17
      src/llama-model.cpp

+ 1 - 0
src/CMakeLists.txt

@@ -24,6 +24,7 @@ add_library(llama
             llama-kv-cache-iswa.cpp
             llama-kv-cache-iswa.cpp
             llama-memory.cpp
             llama-memory.cpp
             llama-memory-hybrid.cpp
             llama-memory-hybrid.cpp
+            llama-memory-hybrid-iswa.cpp
             llama-memory-recurrent.cpp
             llama-memory-recurrent.cpp
             llama-mmap.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
             llama-model-loader.cpp

+ 112 - 0
src/llama-graph.cpp

@@ -7,6 +7,7 @@
 #include "llama-kv-cache.h"
 #include "llama-kv-cache.h"
 #include "llama-kv-cache-iswa.h"
 #include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-hybrid.h"
+#include "llama-memory-hybrid-iswa.h"
 #include "llama-memory-recurrent.h"
 #include "llama-memory-recurrent.h"
 
 
 #include <cassert>
 #include <cassert>
@@ -510,6 +511,76 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     return res;
     return res;
 }
 }
 
 
+void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) {
+    const auto * attn_ctx = mctx->get_attn();
+
+    // base tensors may not be allocated if there are no non-SWA attention layers
+    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
+        attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch);
+        attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch);
+
+        attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    // swa tensors may not be allocated if there are no SWA attention layers
+    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
+        attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch);
+        attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch);
+
+        attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mctx->get_recr()->get_n_rs();
+
+    if (inp_rs->s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer));
+        int32_t * data = (int32_t *) inp_rs->s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mctx->get_recr()->s_copy(i);
+        }
+    }
+}
+
+bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) {
+    const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx);
+
+    this->mctx = mctx;
+
+    bool res = true;
+
+    const auto * attn_ctx = mctx->get_attn();
+
+    // base tensors may not be allocated if there are no non-SWA attention layers
+    if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) {
+        res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
+      //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+        res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
+        res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
+    }
+
+    // swa tensors may not be allocated if there are no SWA attention layers
+    if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) {
+        res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
+      //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
+
+        res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
+        res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
+    }
+
+    res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
+
+    res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
+    res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
+
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
+    return res;
+}
+
 void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
 void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
     // set the inputs only for the active samplers in the current ubatch
     // set the inputs only for the active samplers in the current ubatch
     std::unordered_set<llama_seq_id> active_samplers;
     std::unordered_set<llama_seq_id> active_samplers;
@@ -2056,6 +2127,47 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
     return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
 }
 }
 
 
+llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx);
+
+    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+
+    // build iswa attention input
+    const auto * attn_ctx = mctx_cur->get_attn();
+
+    auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
+
+    const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
+
+    {
+        const auto n_kv = attn_ctx->get_base()->get_n_kv();
+
+        inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
+        inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+        inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        ggml_set_input(inp_attn->self_kq_mask);
+
+        inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
+    }
+
+    {
+        const auto n_kv = attn_ctx->get_swa()->get_n_kv();
+
+        inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
+        inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+        inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
+        ggml_set_input(inp_attn->self_kq_mask_swa);
+
+        inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
+    }
+
+    auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+    return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp));
+}
+
 void llm_graph_context::build_dense_out(
 void llm_graph_context::build_dense_out(
     ggml_tensor * dense_2,
     ggml_tensor * dense_2,
     ggml_tensor * dense_3) const {
     ggml_tensor * dense_3) const {

+ 31 - 0
src/llama-graph.h

@@ -24,6 +24,7 @@ class llama_kv_cache_context;
 class llama_kv_cache_iswa_context;
 class llama_kv_cache_iswa_context;
 class llama_memory_recurrent_context;
 class llama_memory_recurrent_context;
 class llama_memory_hybrid_context;
 class llama_memory_hybrid_context;
+class llama_memory_hybrid_iswa_context;
 
 
 // certain models (typically multi-modal) can produce different types of graphs
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
 enum llm_graph_type {
@@ -397,6 +398,34 @@ public:
     const llama_memory_hybrid_context * mctx;
     const llama_memory_hybrid_context * mctx;
 };
 };
 
 
+class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid_iswa(
+            const llama_cparams & cparams,
+            std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn,
+            std::unique_ptr<llm_graph_input_rs>          inp_rs,
+            const llama_memory_hybrid_iswa_context *     mctx) :
+        inp_attn(std::move(inp_attn)),
+        inp_rs(std::move(inp_rs)),
+        cparams(cparams),
+        mctx(mctx) { }
+    virtual ~llm_graph_input_mem_hybrid_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    bool can_reuse(const llm_graph_params & params) override;
+
+    std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn;
+    std::unique_ptr<llm_graph_input_rs>          inp_rs;
+
+    llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); }
+    llm_graph_input_rs           * get_recr() const { return inp_rs.get(); }
+
+    const llama_cparams cparams;
+
+    const llama_memory_hybrid_iswa_context * mctx;
+};
+
 class llm_graph_input_sampling : public llm_graph_input_i {
 class llm_graph_input_sampling : public llm_graph_input_i {
 public:
 public:
     llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
     llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
@@ -881,6 +910,8 @@ struct llm_graph_context {
 
 
     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
     llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
 
 
+    llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const;
+
     //
     //
     // pooling
     // pooling
     //
     //

+ 275 - 0
src/llama-memory-hybrid-iswa.cpp

@@ -0,0 +1,275 @@
+#include "llama-memory-hybrid-iswa.h"
+
+#include "llama-impl.h"
+#include "llama-model.h"
+#include "llama-context.h"
+
+//
+// llama_memory_hybrid_iswa
+//
+
+llama_memory_hybrid_iswa::llama_memory_hybrid_iswa(
+        const llama_model & model,
+                            /* attn */
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   swa_full,
+                 uint32_t   kv_size,
+                 uint32_t   n_ubatch,
+                 uint32_t   n_pad,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn,
+    const layer_filter_cb & filter_recr) :
+    hparams(model.hparams),
+    mem_attn(new llama_kv_cache_iswa(
+        model,
+        type_k,
+        type_v,
+        v_trans,
+        offload,
+        swa_full,
+        unified,
+        kv_size,
+        n_seq_max,
+        n_ubatch,
+        n_pad,
+        filter_attn == nullptr ?
+            [&](int32_t il) { return !hparams.is_recurrent(il); }
+            : filter_attn,
+        nullptr
+    )),
+    mem_recr(new llama_memory_recurrent(
+        model,
+        type_r,
+        type_s,
+        offload,
+        rs_size,
+        n_seq_max,
+        filter_recr == nullptr ?
+            [&](int32_t il) { return hparams.is_recurrent(il); }
+            : filter_recr
+    )) {}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();
+
+        // follow the recurrent pattern for creating the ubatch splits
+        std::vector<llama_ubatch> ubatches;
+
+        while (true) {
+            llama_ubatch ubatch;
+
+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                // TODO: non-sequential equal split can be done if using unified KV cache
+                //       for simplicity, we always use sequential equal split for now
+                ubatch = balloc.split_equal(n_ubatch, true);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
+        }
+
+        if (balloc.get_n_used() < balloc.get_n_tokens()) {
+            // failed to find a suitable split
+            break;
+        }
+
+        // prepare the recurrent batches first
+        if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined context at this point?
+            LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+
+        // prepare the attention cache (iswa version returns both base and swa slot infos)
+        auto sinfos_base = mem_attn->get_base()->prepare(ubatches);
+        if (sinfos_base.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+
+        auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches);
+        if (sinfos_swa.empty()) {
+            LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+        }
+
+        return std::make_unique<llama_memory_hybrid_iswa_context>(
+                this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
+    } while(false);
+
+    return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() {
+    return std::make_unique<llama_memory_hybrid_iswa_context>(this);
+}
+
+llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize);
+}
+
+bool llama_memory_hybrid_iswa::get_can_shift() const {
+    // Shifting is trivially supported for recurrent
+    return mem_attn->get_can_shift();
+}
+
+void llama_memory_hybrid_iswa::clear(bool data) {
+    mem_attn->clear(data);
+    mem_recr->clear(data);
+}
+
+bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first since it may fail. If it does
+    // fail, the cache will not have been mutated.
+    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return mem_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) {
+    mem_attn->seq_keep(seq_id);
+    mem_recr->seq_keep(seq_id);
+}
+
+void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    mem_attn->seq_add(seq_id, p0, p1, shift);
+    mem_recr->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    mem_attn->seq_div(seq_id, p0, p1, d);
+    mem_recr->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
+}
+
+llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
+}
+
+std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
+    for (const auto & buft_size : mem_recr->memory_breakdown()) {
+        mb[buft_size.first] += buft_size.second;
+    }
+    return mb;
+}
+
+void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
+    mem_attn->state_write(io, seq_id, flags);
+    mem_recr->state_write(io, seq_id, flags);
+}
+
+void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
+    mem_attn->state_read(io, seq_id, flags);
+    mem_recr->state_read(io, seq_id, flags);
+}
+
+llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const {
+    return mem_attn.get();
+}
+
+llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const {
+    return mem_recr.get();
+}
+
+//
+// llama_memory_hybrid_iswa_context
+//
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
+        llama_memory_hybrid_iswa * mem,
+                   llama_context * lctx,
+                            bool   optimize) :
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(
+           llama_memory_hybrid_iswa * mem,
+                    slot_info_vec_t   sinfos_base,
+                    slot_info_vec_t   sinfos_swa,
+          std::vector<llama_ubatch>   ubatches) :
+    ubatches(std::move(ubatches)),
+    // note: here we copy the ubatches. not sure if this is ideal
+    ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
+}
+
+bool llama_memory_hybrid_iswa_context::next() {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+    ctx_attn->next();
+    ctx_recr->next();
+
+    if (++i_next >= ubatches.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+bool llama_memory_hybrid_iswa_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    bool res = true;
+
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
+
+    return res;
+}
+
+llama_memory_status llama_memory_hybrid_iswa_context::get_status() const {
+    return status;
+}
+
+const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const {
+    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    return ubatches[i_next];
+}
+
+const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const {
+    return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get());
+}
+
+const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
+}

+ 140 - 0
src/llama-memory-hybrid-iswa.h

@@ -0,0 +1,140 @@
+#pragma once
+
+#include "llama-batch.h"
+#include "llama-graph.h"
+#include "llama-kv-cache-iswa.h"
+#include "llama-memory.h"
+#include "llama-memory-recurrent.h"
+
+#include <memory>
+#include <vector>
+
+//
+// llama_memory_hybrid_iswa
+//
+
+// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to
+//   support models where each layer may be either attention-based (with SWA support) or recurrent
+
+class llama_memory_hybrid_iswa : public llama_memory_i {
+public:
+    llama_memory_hybrid_iswa(
+        const llama_model & model,
+                            /* attn */
+                ggml_type   type_k,
+                ggml_type   type_v,
+                     bool   v_trans,
+                     bool   swa_full,
+                 uint32_t   kv_size,
+                 uint32_t   n_ubatch,
+                 uint32_t   n_pad,
+                            /* recurrent */
+                ggml_type   type_r,
+                ggml_type   type_s,
+                 uint32_t   rs_size,
+                            /* common */
+                 uint32_t   n_seq_max,
+                     bool   offload,
+                     bool   unified,
+                            /* layer filters */
+    const layer_filter_cb & filter_attn = nullptr,
+    const layer_filter_cb & filter_recr = nullptr);
+
+    ~llama_memory_hybrid_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    llama_memory_context_ptr init_batch(
+            llama_batch_allocr & balloc,
+            uint32_t n_ubatch,
+            bool embd_all) override;
+
+    llama_memory_context_ptr init_full() override;
+
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
+
+    bool get_can_shift() const override;
+
+    void clear(bool data) override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos shift) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0)       override;
+
+    //
+    // llama_memory_hybrid_iswa specific API
+    //
+
+    llama_kv_cache_iswa * get_mem_attn() const;
+    llama_memory_recurrent * get_mem_recr() const;
+
+private:
+    const llama_hparams & hparams;
+
+    const std::unique_ptr<llama_kv_cache_iswa> mem_attn;
+    const std::unique_ptr<llama_memory_recurrent> mem_recr;
+};
+
+class llama_memory_hybrid_iswa_context : public llama_memory_context_i {
+public:
+    using slot_info_vec_t = llama_kv_cache::slot_info_vec_t;
+
+    // init failure
+    explicit llama_memory_hybrid_iswa_context(llama_memory_status status);
+
+    // init full
+    explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem);
+
+    // init update
+    explicit llama_memory_hybrid_iswa_context(
+        llama_memory_hybrid_iswa * mem,
+                   llama_context * lctx,
+                            bool   optimize);
+
+    // init success
+    llama_memory_hybrid_iswa_context(
+           llama_memory_hybrid_iswa * mem,
+                    slot_info_vec_t   sinfos_base,
+                    slot_info_vec_t   sinfos_swa,
+          std::vector<llama_ubatch>   ubatches);
+
+    ~llama_memory_hybrid_iswa_context() = default;
+
+    bool next()  override;
+    bool apply() override;
+
+    llama_memory_status  get_status() const override;
+    const llama_ubatch & get_ubatch() const override;
+
+    //
+    // llama_memory_hybrid_iswa_context
+    //
+
+    const llama_kv_cache_iswa_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;
+
+private:
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<llama_ubatch> ubatches;
+
+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;
+
+    const llama_memory_status status;
+};

+ 39 - 17
src/llama-model.cpp

@@ -8,6 +8,7 @@
 #include "llama-kv-cache.h"
 #include "llama-kv-cache.h"
 #include "llama-kv-cache-iswa.h"
 #include "llama-kv-cache-iswa.h"
 #include "llama-memory-hybrid.h"
 #include "llama-memory-hybrid.h"
+#include "llama-memory-hybrid-iswa.h"
 #include "llama-memory-recurrent.h"
 #include "llama-memory-recurrent.h"
 
 
 #include "ggml-cpp.h"
 #include "ggml-cpp.h"
@@ -7528,23 +7529,44 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                         };
                         };
                     }
                     }
 
 
-                    res = new llama_memory_hybrid(
-                        /* model             */ *this,
-                        /* attn_type_k       */ params.type_k,
-                        /* attn_type_v       */ params.type_v,
-                        /* attn_v_trans      */ !cparams.flash_attn,
-                        /* attn_kv_size      */ cparams.n_ctx,
-                        /* attn_n_pad        */ 1,
-                        /* attn_n_swa        */ hparams.n_swa,
-                        /* attn_swa_type     */ hparams.swa_type,
-                        /* recurrent_type_k  */ GGML_TYPE_F32,
-                        /* recurrent_type_v  */ GGML_TYPE_F32,
-                        /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
-                        /* n_seq_max         */ cparams.n_seq_max,
-                        /* offload           */ cparams.offload_kqv,
-                        /* unified           */ cparams.kv_unified,
-                        /* filter_attn       */ std::move(filter_attn),
-                        /* filter_recr       */ std::move(filter_recr));
+                    if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                        // Use hybrid-iswa for hybrid models with SWA
+                        res = new llama_memory_hybrid_iswa(
+                            /* model             */ *this,
+                            /* attn_type_k       */ params.type_k,
+                            /* attn_type_v       */ params.type_v,
+                            /* attn_v_trans      */ !cparams.flash_attn,
+                            /* attn_swa_full     */ params.swa_full,
+                            /* attn_kv_size      */ cparams.n_ctx,
+                            /* attn_n_ubatch     */ cparams.n_ubatch,
+                            /* attn_n_pad        */ 1,
+                            /* recurrent_type_r  */ GGML_TYPE_F32,
+                            /* recurrent_type_s  */ GGML_TYPE_F32,
+                            /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                            /* n_seq_max         */ cparams.n_seq_max,
+                            /* offload           */ cparams.offload_kqv,
+                            /* unified           */ cparams.kv_unified,
+                            /* filter_attn       */ std::move(filter_attn),
+                            /* filter_recr       */ std::move(filter_recr));
+                    } else {
+                        res = new llama_memory_hybrid(
+                            /* model             */ *this,
+                            /* attn_type_k       */ params.type_k,
+                            /* attn_type_v       */ params.type_v,
+                            /* attn_v_trans      */ !cparams.flash_attn,
+                            /* attn_kv_size      */ cparams.n_ctx,
+                            /* attn_n_pad        */ 1,
+                            /* attn_n_swa        */ hparams.n_swa,
+                            /* attn_swa_type     */ hparams.swa_type,
+                            /* recurrent_type_k  */ GGML_TYPE_F32,
+                            /* recurrent_type_v  */ GGML_TYPE_F32,
+                            /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
+                            /* n_seq_max         */ cparams.n_seq_max,
+                            /* offload           */ cparams.offload_kqv,
+                            /* unified           */ cparams.kv_unified,
+                            /* filter_attn       */ std::move(filter_attn),
+                            /* filter_recr       */ std::move(filter_recr));
+                    }
                 } else {
                 } else {
                     llama_memory_i::layer_reuse_cb reuse = nullptr;
                     llama_memory_i::layer_reuse_cb reuse = nullptr;