Просмотр исходного кода

kv-cache : add SWA support (#13194)

* kv-cache : prepare for SWA

ggml-ci

* kv-cache : initial iSWA implementation

ggml-ci

* kv-cache : rework error recovery logic

ggml-ci

* models : fix Phi-3 SWA parameters

ggml-ci

* model : adjust Granite to rope factor changes

ggml-ci

* server : check if context can do shifts

ggml-ci

* iswa : for now, always enable shifts (experiment)

ggml-ci

* kv-cache : simplify SWA logic

ggml-ci

* kv-cache : apply defrag when we fail to find slots for the batch

ggml-ci

* llama : update docs about llama_decode

ggml-ci

* kv-cache : update warning logs when no space for the batch is available

ggml-ci

* llama : add llama_kv_self_seq_pos_min()

* kv-cache : keep track of partial SWA computes and print warnings

* server : disallow use cases involving partial SWA context

ggml-ci

* llama : add param to control SWA cache size

ggml-ci

* minor : clean-up

ggml-ci
Georgi Gerganov 8 месяцев назад
Родитель
Сommit
e298d2fbd0

+ 8 - 0
common/arg.cpp

@@ -1445,6 +1445,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.n_keep = value;
         }
     ));
+    add_opt(common_arg(
+        {"--swa-full"},
+        string_format("use full-size SWA cache (default: %s)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)", params.swa_full ? "true" : "false"),
+        [](common_params & params) {
+            params.swa_full = true;
+        }
+    ));
     add_opt(common_arg(
         {"--no-context-shift"},
         string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

+ 1 - 0
common/common.cpp

@@ -1136,6 +1136,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.flash_attn        = params.flash_attn;
     cparams.no_perf           = params.no_perf;
     cparams.op_offload        = !params.no_op_offload;
+    cparams.swa_full          = params.swa_full;
 
     if (params.reranking) {
         cparams.embeddings    = true;

+ 1 - 0
common/common.h

@@ -323,6 +323,7 @@ struct common_params {
     bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
     bool ctx_shift         = true;  // context shift on inifinite text generation
+    bool swa_full          = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap          = true;  // use mmap for faster loads

+ 20 - 8
include/llama.h

@@ -361,10 +361,11 @@ extern "C" {
 
         // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
         bool embeddings;  // if true, extract embeddings (together with logits)
-        bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-        bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-        bool no_perf;     // whether to measure performance timings
-        bool op_offload;  // whether to offload host tensor operations to device
+        bool offload_kqv; // offload the KQV ops (including the KV cache) to GPU
+        bool flash_attn;  // use flash attention [EXPERIMENTAL]
+        bool no_perf;     // measure performance timings
+        bool op_offload;  // offload host tensor operations to device
+        bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     };
 
     // model quantization parameters
@@ -730,10 +731,18 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the smallest position present in the KV cache for the specified sequence
+    // This is typically non-zero only for SWA caches
+    // Return -1 if the sequence is empty
+    LLAMA_API llama_pos llama_kv_self_seq_pos_min(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
     // Returns the largest position present in the KV cache for the specified sequence
+    // Return -1 if the sequence is empty
     LLAMA_API llama_pos llama_kv_self_seq_pos_max(
             struct llama_context * ctx,
-                     llama_seq_id   seq_id);
+                    llama_seq_id   seq_id);
 
     // Defragment the KV cache
     // This will be applied:
@@ -943,9 +952,12 @@ extern "C" {
     // Requires KV cache.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    //   0 - success
-    //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    // < 0 - error. the KV cache state is restored to the state before this call
+    // Upon non-zero return values, the KV cache state is restored to the state before this call
+    //    0 - success
+    //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    //    2 - aborted
+    //   -1 - invalid input batch
+    // < -1 - error
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
               struct llama_batch   batch);

+ 30 - 6
src/llama-context.cpp

@@ -93,6 +93,7 @@ llama_context::llama_context(
     }
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
+
     cparams.op_offload = params.op_offload;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
@@ -176,8 +177,9 @@ llama_context::llama_context(
     // init the memory module
     if (!hparams.vocab_only) {
         llama_memory_params params_mem = {
-            /*.type_k =*/ params.type_k,
-            /*.type_v =*/ params.type_v,
+            /*.type_k   =*/ params.type_k,
+            /*.type_v   =*/ params.type_v,
+            /*.swa_full =*/ params.swa_full,
         };
 
         memory.reset(model.create_memory(params_mem, cparams));
@@ -947,8 +949,6 @@ int llama_context::decode(llama_batch & inp_batch) {
 
         // find KV slot
         if (!kv_self->find_slot(ubatch)) {
-            LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
-
             return 1;
         }
 
@@ -2093,6 +2093,7 @@ llama_context_params llama_context_default_params() {
         /*.flash_attn                  =*/ false,
         /*.no_perf                     =*/ true,
         /*.op_offload                  =*/ true,
+        /*.swa_full                    =*/ true,
     };
 
     return result;
@@ -2467,6 +2468,15 @@ void llama_kv_self_seq_div(
     kv->seq_div(seq_id, p0, p1, d);
 }
 
+llama_pos llama_kv_self_seq_pos_min(llama_context * ctx, llama_seq_id seq_id) {
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return -1;
+    }
+
+    return kv->seq_pos_min(seq_id);
+}
+
 // deprecated
 llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return llama_kv_self_seq_pos_max(ctx, seq_id);
@@ -2475,7 +2485,7 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
-        return 0;
+        return -1;
     }
 
     return kv->seq_pos_max(seq_id);
@@ -2637,7 +2647,21 @@ int32_t llama_encode(
 int32_t llama_decode(
         llama_context * ctx,
           llama_batch   batch) {
-    const int ret = ctx->decode(batch);
+    int ret = ctx->decode(batch);
+
+    // defrag and try again
+    // TODO: distinguish return code when we are sure that even after defrag there is no space available
+    if (ret == 1) {
+        llama_kv_self_defrag(ctx);
+        ret = ctx->decode(batch);
+
+        if (ret == 1) {
+            LLAMA_LOG_WARN("%s: failed to find KV cache slot for batch of size %d\n", __func__, batch.n_tokens);
+
+            return ret;
+        }
+    }
+
     if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }

+ 134 - 233
src/llama-graph.cpp

@@ -9,33 +9,6 @@
 #include <cmath>
 #include <cstring>
 
-static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
-    // TODO move to hparams if a T5 variant appears that uses a different value
-    const int64_t max_distance = 128;
-
-    if (bidirectional) {
-        n_buckets >>= 1;
-    }
-
-    const int64_t max_exact = n_buckets >> 1;
-
-    int32_t relative_position = x - y;
-    int32_t relative_bucket = 0;
-
-    if (bidirectional) {
-        relative_bucket += (relative_position > 0) * n_buckets;
-        relative_position = abs(relative_position);
-    } else {
-        relative_position = -std::min<int32_t>(relative_position, 0);
-    }
-
-    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
-    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
-    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
-
-    return relative_bucket;
-}
-
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
         const int64_t n_tokens = ubatch->n_tokens;
@@ -110,22 +83,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
 
 void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
     if (pos_bucket) {
-        const int64_t n_tokens = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer));
-        GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
-
-        int32_t * data = (int32_t *) pos_bucket->data;
-
-        const int64_t n_kv = kv_self->n;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int j = 0; j < n_tokens; ++j) {
-                for (int i = 0; i < n_kv; ++i) {
-                    data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(kv_self->cells[i].pos, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
-                }
-            }
-        }
+        kv_self->set_input_pos_bucket(pos_bucket, ubatch);
     }
 }
 
@@ -403,99 +361,18 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
-    if (self_kq_mask || self_kq_mask_swa) {
-        const int64_t n_kv         = kv_self->n;
-        const int64_t n_tokens     = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs       = ubatch->n_seqs;
-
-        float * data     = nullptr;
-        float * data_swa = nullptr;
-
-        if (self_kq_mask) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-            data = (float *) self_kq_mask->data;
-        }
-
-        if (self_kq_mask_swa) {
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
-            data_swa = (float *) self_kq_mask_swa->data;
-        }
-
-        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
-        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
-        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
-        //   Causal mask:
-        //      xxx-------
-        //      xxxx------
-        //      xxxxx-----
-        //   Non-causal mask:
-        //      xxxxx-----
-        //      xxxxx-----
-        //      xxxxx-----
-        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-        for (int h = 0; h < 1; ++h) {
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
-                    for (int i = 0; i < n_kv; ++i) {
-                        float f;
-                        // mask the token if:
-                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
-                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
-                        ) {
-                            f = -INFINITY;
-                        } else {
-                            if (hparams.use_alibi) {
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
-                        }
-
-                        if (data) {
-                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-
-                        // may need to cut off old tokens for sliding window
-                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
-                        if (data_swa) {
-                            if (hparams.n_attn_chunk) {
-                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
-                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
-                                    f = -INFINITY;
-                                }
-                            } else {
-                                if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
-                                    f = -INFINITY;
-                                }
-                            }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
-                        }
-                    }
-                }
-            }
+    if (self_kq_mask) {
+        kv_self->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+}
 
-            // mask padded tokens
-            if (data) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
+void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        kv_self->get_kv_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
 
-            // mask padded tokens
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
-                }
-            }
-        }
+    if (self_kq_mask_swa) {
+        kv_self->get_kv_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
     }
 }
 
@@ -545,7 +422,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     n_layer          (hparams.n_layer),
     n_rot            (hparams.n_rot),
     n_ctx            (cparams.n_ctx),
-    n_ctx_per_seq    (cparams.n_ctx / cparams.n_seq_max),
     n_head           (hparams.n_head()),
     n_head_kv        (hparams.n_head_kv()),
     n_embd_head_k    (hparams.n_embd_head_k),
@@ -1153,7 +1029,7 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
 
     auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, kv_self);
 
-    const auto n_kv = kv_self->n;
+    const auto n_kv = kv_self->get_n();
 
     auto & cur = inp->pos_bucket;
 
@@ -1188,16 +1064,12 @@ ggml_tensor * llm_graph_context::build_attn_mha(
          ggml_tensor * kq_b,
          ggml_tensor * kq_mask,
          ggml_tensor * v_mla,
-             bool      v_trans,
              float     kq_scale) const {
-  //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-  //const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
-
-  //const int64_t n_head    = hparams.n_head(il);
-  //const int64_t n_head_kv = hparams.n_head_kv(il);
+    const bool v_trans = v->nb[1] > v->nb[2];
 
-  //const auto & n_embd_head_k = hparams.n_embd_head_k;
-  //const auto & n_embd_head_v = hparams.n_embd_head_v;
+    q = ggml_permute(ctx0, q, 0, 2, 1, 3);
+    k = ggml_permute(ctx0, k, 0, 2, 1, 3);
+    v = ggml_permute(ctx0, v, 0, 2, 1, 3);
 
     const auto n_tokens = q->ne[1];
     const auto n_head   = q->ne[2];
@@ -1336,17 +1208,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1369,22 +1235,17 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
-    const auto n_kv = kv_self->n;
-
-    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-    //cb(inp->self_kq_mask, "KQ_mask", -1);
-    ggml_set_input(inp->self_kq_mask);
-
-    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    {
+        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.n_swa == 0         && "Use llama_kv_cache_unified_iswa for SWA");
 
-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0);
+        const auto n_kv = kv_self->get_n();
 
-        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
-        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
-        ggml_set_input(inp->self_kq_mask_swa);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
 
-        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }
 
     return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
@@ -1409,81 +1270,100 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_build_forward_expand(gf, v_cur);
 
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
-    const auto & n_ctx = cparams.n_ctx;
 
-    const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
-    const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv_self->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv_self->cpy_v(ctx0, v_cur, il));
+    }
 
-    const auto n_tokens = q_cur->ne[2];
+    const auto & kq_mask = inp->get_kq_mask();
 
-    const bool v_trans = !cparams.flash_attn;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv_self->get_k(ctx0, il);
+    ggml_tensor * v = kv_self->get_v(ctx0, il);
 
-    // store to KV cache
-    {
-        const auto kv_head = kv_self->head;
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+    cb(cur, "kqv_out", il);
 
-        GGML_ASSERT(kv_self->size == n_ctx);
+    if (wo) {
+        cur = build_lora_mm(wo, cur);
+    }
 
-        ggml_tensor * k_cache_view = ggml_view_1d(ctx0, kv_self->k_l[il], n_tokens*n_embd_k_gqa, ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa)*kv_head);
-        //cb(k_cache_view, "k_cache_view", il);
+    if (wo_b) {
+        cur = ggml_add(ctx0, cur, wo_b);
+    }
 
-        // note: storing RoPE-ed version of K in the KV cache
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, k_cur, k_cache_view));
+    return cur;
+}
 
-        v_cur = ggml_reshape_2d(ctx0, v_cur, n_embd_v_gqa, n_tokens);
+llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
 
-        ggml_tensor * v_cache_view = nullptr;
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_self);
 
-        if (!v_trans) {
-            v_cache_view = ggml_view_1d(ctx0, kv_self->v_l[il], n_tokens*n_embd_v_gqa, ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa)*kv_head);
-        } else {
-            // note: the V cache is transposed when not using flash attention
-            v_cache_view = ggml_view_2d(ctx0, kv_self->v_l[il], n_tokens, n_embd_v_gqa,
-                    (  n_ctx)*ggml_element_size(kv_self->v_l[il]),
-                    (kv_head)*ggml_element_size(kv_self->v_l[il]));
+    {
+        const auto n_kv = kv_self->get_kv_base()->get_n();
 
-            v_cur = ggml_transpose(ctx0, v_cur);
-        }
-        //cb(v_cache_view, "v_cache_view", il);
+        inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask, "KQ_mask", -1);
+        ggml_set_input(inp->self_kq_mask);
+
+        inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+    }
 
-        ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
+    if (hparams.n_swa_pattern > 1) {
+        GGML_ASSERT(hparams.n_swa > 0          && "Use llama_kv_cache_unified for non-SWA");
+
+        const auto n_kv = kv_self->get_kv_swa()->get_n();
+
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
+        ggml_set_input(inp->self_kq_mask_swa);
+
+        inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
     }
 
+    return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
+}
+
+ggml_tensor * llm_graph_context::build_attn(
+        llm_graph_input_attn_kv_unified_iswa * inp,
+        ggml_cgraph * gf,
+        ggml_tensor * wo,
+        ggml_tensor * wo_b,
+        ggml_tensor * q_cur,
+        ggml_tensor * k_cur,
+        ggml_tensor * v_cur,
+        ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
+            float     kq_scale,
+            int       il) const {
+    // these nodes are added to the graph together so that they are not reordered
+    // by doing so, the number of splits in the graph is reduced
+    ggml_build_forward_expand(gf, q_cur);
+    ggml_build_forward_expand(gf, k_cur);
+    ggml_build_forward_expand(gf, v_cur);
+
     const bool is_swa = hparams.is_swa(il);
 
+    const llama_kv_cache_unified_iswa * kv_self = static_cast<const llama_kv_cache_unified_iswa *>(memory);
+
+    const auto * kv = is_swa ? kv_self->get_kv_swa() : kv_self->get_kv_base();
+
+    // store to KV cache
+    {
+        ggml_build_forward_expand(gf, kv->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, kv->cpy_v(ctx0, v_cur, il));
+    }
+
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
-    const auto n_kv = kv_self->n;
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = kv->get_k(ctx0, il);
+    ggml_tensor * v = kv->get_v(ctx0, il);
 
-    const int64_t n_head_kv = hparams.n_head_kv(il);
-
-    const auto & n_embd_head_k = hparams.n_embd_head_k;
-    const auto & n_embd_head_v = hparams.n_embd_head_v;
-
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k =
-        ggml_view_3d(ctx0, kv_self->k_l[il],
-                n_embd_head_k, n_kv, n_head_kv,
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
-                ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
-                0);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = !v_trans ?
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_embd_head_v, n_kv, n_head_kv,
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
-                ggml_row_size(kv_self->v_l[il]->type, n_embd_head_v),
-                0) :
-        ggml_view_3d(ctx0, kv_self->v_l[il],
-                n_kv, n_embd_head_v, n_head_kv,
-                ggml_element_size(kv_self->v_l[il])*n_ctx,
-                ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
-                0);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1534,17 +1414,11 @@ ggml_tensor * llm_graph_context::build_attn(
 
     const auto & kq_mask = inp->get_kq_mask_cross();
 
-    ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
-    //cb(q, "q", il);
-
-    ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
-    //cb(k, "k", il);
-
-    ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
-    //cb(k, "v", il);
-
-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);
+    ggml_tensor * q = q_cur;
+    ggml_tensor * k = k_cur;
+    ggml_tensor * v = v_cur;
 
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
 
     if (wo) {
@@ -1712,3 +1586,30 @@ void llm_graph_context::build_pooling(
 
     ggml_build_forward_expand(gf, cur);
 }
+
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
+    // TODO move to hparams if a T5 variant appears that uses a different value
+    const int64_t max_distance = 128;
+
+    if (bidirectional) {
+        n_buckets >>= 1;
+    }
+
+    const int64_t max_exact = n_buckets >> 1;
+
+    int32_t relative_position = x - y;
+    int32_t relative_bucket = 0;
+
+    if (bidirectional) {
+        relative_bucket += (relative_position > 0) * n_buckets;
+        relative_position = abs(relative_position);
+    } else {
+        relative_position = -std::min<int32_t>(relative_position, 0);
+    }
+
+    int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (n_buckets - max_exact) / log(1.0 * max_distance / max_exact));
+    relative_position_if_large = std::min<int32_t>(relative_position_if_large, n_buckets - 1);
+    relative_bucket += (relative_position < max_exact ? relative_position : relative_position_if_large);
+
+    return relative_bucket;
+}

+ 49 - 7
src/llama-graph.h

@@ -19,6 +19,7 @@ struct llama_cparams;
 
 class llama_memory_i;
 class llama_kv_cache_unified;
+class llama_kv_cache_unified_iswa;
 class llama_kv_cache_recurrent;
 
 // certain models (typically multi-modal) can produce different types of graphs
@@ -255,6 +256,31 @@ public:
 
     void set_input(const llama_ubatch * ubatch) override;
 
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_kv_cache_unified * kv_self;
+};
+
+class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
+public:
+    llm_graph_input_attn_kv_unified_iswa(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa * kv_self) :
+        hparams(hparams),
+        cparams(cparams),
+        kv_self(kv_self) {
+    }
+    ~llm_graph_input_attn_kv_unified_iswa() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
     ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
 
@@ -266,7 +292,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
-    const llama_kv_cache_unified * kv_self;
+    const llama_kv_cache_unified_iswa * kv_self;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -378,7 +404,6 @@ struct llm_graph_context {
     const int64_t n_layer;
     const int64_t n_rot;
     const int64_t n_ctx;       // user-specified context size (can be different from n_ctx_train)
-    const int64_t n_ctx_per_seq;
     const int64_t n_head;
     const int64_t n_head_kv;
     const int64_t n_embd_head_k;
@@ -507,13 +532,12 @@ struct llm_graph_context {
 
     ggml_tensor * build_attn_mha(
              ggml_cgraph * gf,
-             ggml_tensor * q,     // [n_embd_head_q, n_tokens, n_head_q]
-             ggml_tensor * k,     // [n_embd_head_k, n_tokens, n_head_k]
-             ggml_tensor * v,     // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
+             ggml_tensor * q,       // [n_embd_head_q, n_head_q, n_tokens]
+             ggml_tensor * k,       // [n_embd_head_k, n_head_k, n_tokens]
+             ggml_tensor * v,       // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
              ggml_tensor * kq_b,
              ggml_tensor * kq_mask,
-             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-                    bool   v_trans,
+             ggml_tensor * v_mla,   // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
                    float   kq_scale) const;
 
     llm_graph_input_attn_no_cache * build_attn_inp_no_cache() const;
@@ -546,6 +570,21 @@ struct llm_graph_context {
                   float   kq_scale,
                     int   il) const;
 
+    llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
+
+    ggml_tensor * build_attn(
+            llm_graph_input_attn_kv_unified_iswa * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+                  float   kq_scale,
+                    int   il) const;
+
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
 
     ggml_tensor * build_attn(
@@ -596,3 +635,6 @@ struct llm_graph_context {
             ggml_tensor * cls_out,
             ggml_tensor * cls_out_b) const;
 };
+
+// TODO: better name
+int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional);

+ 14 - 5
src/llama-hparams.h

@@ -14,6 +14,12 @@ enum llama_expert_gating_func_type {
     LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
 };
 
+enum llama_swa_type {
+    LLAMA_SWA_TYPE_NONE     = 0,
+    LLAMA_SWA_TYPE_STANDARD = 1,
+    LLAMA_SWA_TYPE_CHUNKED  = 2,
+};
+
 struct llama_hparams_posnet {
     uint32_t n_embd;
     uint32_t n_layer;
@@ -35,8 +41,6 @@ struct llama_hparams {
     uint32_t n_embd_features = 0;
     uint32_t n_layer;
     uint32_t n_rot;
-    uint32_t n_swa = 0; // sliding window attention (SWA)
-    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
     uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
     uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
     uint32_t n_expert = 0;
@@ -96,6 +100,12 @@ struct llama_hparams {
 
     std::array<int, 4> rope_sections;
 
+    // Sliding Window Attention (SWA)
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
+    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention
+
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
     uint32_t ssm_d_inner = 0;
@@ -116,11 +126,10 @@ struct llama_hparams {
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
+    bool use_kq_norm   = true;
 
+    // llama4
     uint32_t n_moe_layer_step        = 0;
-    bool     use_kq_norm             = true;
-    uint32_t n_attn_chunk            = 0;
-    // values below seems to be fixed on llama4
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;

Разница между файлами не показана из-за своего большого размера
+ 442 - 161
src/llama-kv-cache.cpp


+ 198 - 52
src/llama-kv-cache.h

@@ -8,6 +8,7 @@
 #include "ggml-cpp.h"
 
 #include <set>
+#include <unordered_map>
 #include <vector>
 
 struct llama_cparams;
@@ -40,6 +41,9 @@ struct llama_kv_cache : public llama_memory_i {
     // batch processing
     //
 
+    // =============================================================================================================
+    // TODO: refactor  and simplify this
+
     virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0;
 
     // different KV caches require different batch splitting strategies
@@ -48,6 +52,8 @@ struct llama_kv_cache : public llama_memory_i {
     // find an empty slot of size "n_tokens" in the cache
     virtual bool find_slot(const llama_ubatch & batch) = 0;
 
+    // =============================================================================================================
+
     // getters
     virtual int32_t   get_n_tokens()   const = 0;
     virtual int32_t   get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache
@@ -87,38 +93,24 @@ private:
 // llama_kv_cache_unified
 //
 
-// TODO: add notion of max sequences
 class llama_kv_cache_unified : public llama_kv_cache {
 public:
-    struct kv_cell {
-        llama_pos pos   = -1;
-        llama_pos delta =  0;
-
-        std::set<llama_seq_id> seq_id;
-
-        bool has_seq_id(const llama_seq_id & id) const {
-            return seq_id.find(id) != seq_id.end();
-        }
-
-        bool is_empty() const {
-            return seq_id.empty();
-        }
-
-        bool is_same_seq(const kv_cell & other) const {
-            return seq_id == other.seq_id;
-        }
-    };
-
     static uint32_t get_padding(const llama_cparams & cparams);
 
+    // this callback is used to filter out layers that should not be included in the cache
+    using layer_filter_cb = std::function<bool(int32_t il)>;
+
     llama_kv_cache_unified(
-            const llama_model & model,
-                    ggml_type   type_k,
-                    ggml_type   type_v,
-                         bool   v_trans,
-                         bool   offload,
-                     uint32_t   kv_size,
-                     uint32_t   padding);
+            const llama_model &  model,
+              layer_filter_cb && filter,
+                    ggml_type    type_k,
+                    ggml_type    type_v,
+                         bool    v_trans,
+                         bool    offload,
+                     uint32_t    kv_size,
+                     uint32_t    padding,
+                     uint32_t    n_swa,
+               llama_swa_type    swa_type);
 
     ~llama_kv_cache_unified() = default;
 
@@ -130,10 +122,11 @@ public:
 
     bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
     void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
-    void seq_keep(llama_seq_id seq_id) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     //
@@ -150,7 +143,6 @@ public:
     void set_full() override;
 
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     // updates the cache head
@@ -169,29 +161,72 @@ public:
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
-    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
 
-    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
-    uint32_t size = 0; // total number of cells, shared across all sequences
-    uint32_t used = 0; // used cells (i.e. at least one seq_id)
+    //
+    // llama_kv_cache_unified specific API
+    //
 
-    // computed before each graph build
-    uint32_t n = 0;
+    uint32_t get_n() const;
+    uint32_t get_size() const;
 
-    std::vector<kv_cell> cells;
+    // get views of the current state of the cache
+    ggml_tensor * get_k(ggml_context * ctx, int32_t il) const;
+    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
 
-    std::vector<ggml_tensor *> k_l; // per layer
-    std::vector<ggml_tensor *> v_l;
+    // store k_cur and v_cur in the cache based on the current head location
+    ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
+    ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
+
+    void prune_swa(llama_seq_id seq_id, llama_pos pmin, llama_pos pmax);
+
+    void set_input_kq_mask   (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
+    void set_input_k_shift   (ggml_tensor * dst) const;
+    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
 private:
     const llama_model & model;
     const llama_hparams & hparams;
 
+    struct kv_cell {
+        llama_pos pos   = -1;
+        llama_pos delta =  0;
+
+        // TODO: replace with bitset uint64_t
+        std::set<llama_seq_id> seq_id;
+
+        bool has_seq_id(const llama_seq_id & id) const {
+            return seq_id.find(id) != seq_id.end();
+        }
+
+        bool is_empty() const {
+            return seq_id.empty();
+        }
+
+        bool is_same_seq(const kv_cell & other) const {
+            return seq_id == other.seq_id;
+        }
+    };
+
+    struct kv_layer {
+        // layer index in the model
+        // note: can be different from the layer index in the KV cache
+        uint32_t il;
+
+        ggml_tensor * k;
+        ggml_tensor * v;
+    };
+
     bool has_shift = false;
     bool do_defrag = false;
-
     bool v_trans   = true;  // the value tensor is transposed
-    bool can_shift = false;
+
+    uint32_t head = 0; // the location where the batch will be placed in the cache (see find_slot())
+    uint32_t size = 0; // total number of cells, shared across all sequences
+    uint32_t used = 0; // used cells (i.e. at least one seq_id) (TODO: add `struct kv_cells` and keep track automaticallt)
+
+    // computed before each graph build
+    uint32_t n = 0;
 
     // required padding
     uint32_t padding = 1;
@@ -199,9 +234,29 @@ private:
     ggml_type type_k = GGML_TYPE_F16;
     ggml_type type_v = GGML_TYPE_F16;
 
+    // SWA
+    uint32_t n_swa = 0;
+
+    llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr>        ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 
+    std::vector<kv_cell>  cells;  // TODO: replace with `struct kv_cells`
+    std::vector<kv_layer> layers;
+
+    // model layer id -> KV cache layer id
+    std::unordered_map<int32_t, int32_t> map_layer_ids;
+
+    // recovery information used to restore the KV cells to their original state in case of a failure
+    struct {
+        void clear() {
+            cells.clear();
+        }
+
+        std::unordered_map<uint32_t, kv_cell> cells;
+    } recovery;
+
     // defrag
     struct {
         std::vector<uint32_t> ids;
@@ -210,17 +265,6 @@ private:
     // return true if cells have been moved
     bool defrag_prepare(int32_t n_max_nodes);
 
-    // commit/restore cache
-    struct slot_range {
-        uint32_t c0 = 0; // note: these are cell indices, not sequence positions
-        uint32_t c1 = 0;
-    };
-
-    // pending cell updates that are not yet committed
-    struct {
-        std::vector<slot_range> ranges;
-    } pending;
-
     // find how many cells are currently in use
     uint32_t cell_max() const;
 
@@ -229,6 +273,8 @@ private:
     size_t size_k_bytes() const;
     size_t size_v_bytes() const;
 
+    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+
     ggml_tensor * build_rope_shift(
             const llama_cparams & cparams,
                    ggml_context * ctx,
@@ -255,6 +301,106 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+//
+// llama_kv_cache_unified_iswa
+//
+
+// utilizes two instances of llama_kv_cache_unified
+//   the first instance is for the non-SWA layers of the model and the second instance is for the SWA layers
+//   upon successful commit, the SWA cache removes old tokens outside the n_swa window
+
+class llama_kv_cache_unified_iswa : public llama_kv_cache {
+public:
+    llama_kv_cache_unified_iswa(
+            const llama_model & model,
+                    ggml_type   type_k,
+                    ggml_type   type_v,
+                         bool   v_trans,
+                         bool   offload,
+                     uint32_t   kv_size,
+                         bool   swa_full,
+                     uint32_t   n_seq_max,
+                     uint32_t   n_batch,
+                     uint32_t   padding);
+
+    ~llama_kv_cache_unified_iswa() = default;
+
+    //
+    // llama_memory_i
+    //
+
+    void clear() override;
+
+    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
+    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
+    void seq_keep(llama_seq_id seq_id)                                                          override;
+    void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
+    void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
+
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
+    llama_pos seq_pos_max(llama_seq_id seq_id) const override;
+
+    //
+    // llama_kv_cache
+    //
+
+    void restore() override;
+    void commit()  override;
+
+    bool update(llama_context & ctx) override;
+
+    void defrag_sched(float thold) override;
+
+    void set_full() override;
+
+    llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
+    llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
+
+    bool find_slot(const llama_ubatch & batch) override;
+
+    int32_t get_n_tokens()   const override;
+    int32_t get_used_cells() const override;
+
+    // TODO: better data structures to reduce the cost of this operation
+    llama_pos get_pos_max() const override;
+
+    bool get_can_shift() const override;
+
+    // state write/load
+
+    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
+    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1)       override;
+
+    //
+    // llama_kv_cache_unified_iswa specific API
+    //
+
+    llama_kv_cache_unified * get_kv_base() const;
+    llama_kv_cache_unified * get_kv_swa () const;
+
+private:
+    const llama_hparams & hparams;
+
+    bool do_prune = true;
+
+    struct {
+        struct entry {
+            llama_pos pmin;
+            llama_pos pmax;
+        };
+
+        void clear() {
+            pos.clear();
+        }
+
+        // used to perform SWA pruning of old tokens
+        std::unordered_map<llama_seq_id, entry> pos;
+    } pending;
+
+    std::unique_ptr<llama_kv_cache_unified> kv_base;
+    std::unique_ptr<llama_kv_cache_unified> kv_swa;
+};
+
 //
 // llama_kv_cache_recurrent
 //
@@ -302,6 +448,7 @@ public:
     void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) override;
     void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) override;
 
+    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
     //
@@ -318,7 +465,6 @@ public:
     void set_full() override;
 
     llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
-
     llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
 
     bool find_slot(const llama_ubatch & batch) override;

+ 3 - 2
src/llama-memory.h

@@ -7,8 +7,8 @@ struct llama_memory_params {
     ggml_type type_k;
     ggml_type type_v;
 
-    // parameters for other types of memory
-    // ...
+    // use full-size SWA cache
+    bool swa_full;
 };
 
 // general concept of LLM memory
@@ -25,6 +25,7 @@ public:
     virtual void seq_add (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, llama_pos delta) = 0;
     virtual void seq_div (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1, int d) = 0;
 
+    virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
     virtual bool get_can_edit() const = 0;

+ 261 - 90
src/llama-model.cpp

@@ -571,9 +571,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
                 ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP,   hparams.n_moe_layer_step);
+
+                hparams.swa_type      = LLAMA_SWA_TYPE_CHUNKED;
+                hparams.n_swa         = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
                 hparams.n_swa_pattern = 4;    // pattern: 3 chunked - 1 full
-                hparams.n_attn_chunk  = 8192; // should this be a gguf kv? currently it's the same for Scout and Maverick
-                hparams.n_swa = 1; // TODO @ngxson : this is added to trigger the SWA branch (we store the chunked attn mask in the SWA tensor), will need to clean this up later
 
                 switch (hparams.n_expert) {
                     case 16:  type = LLM_TYPE_17B_16E; break;
@@ -855,20 +856,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
                 if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
                     // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
+                    LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+
                     hparams.n_swa = 2047;
                 } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-mini-128k-instruct
-                    // note: this seems incorrect because the window is bigger than the train context?
-                    hparams.n_swa = 262144;
+                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
                     // default value for Phi-3-medium-128k-instruct
-                    // note: this seems incorrect because the window is equal to the train context?
-                    hparams.n_swa = 131072;
+                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
                 }
+
                 bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
                 if (!found_swa && hparams.n_swa == 0) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
+
+                if (hparams.n_swa > hparams.n_ctx_train) {
+                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);
+
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+
+                    hparams.n_swa         = hparams.n_ctx_train;
+                    hparams.n_swa_pattern = 1;
+                }
             } break;
         case LLM_ARCH_PHIMOE:
             {
@@ -937,6 +960,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA2:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa = 4096; // default value of gemma 2
                 hparams.n_swa_pattern = 2;
                 hparams.attn_soft_cap = true;
@@ -955,6 +979,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GEMMA3:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa_pattern = 6;
 
                 hparams.rope_freq_base_train_swa  = 10000.0f;
@@ -1039,6 +1064,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_COHERE2:
             {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                 hparams.n_swa_pattern = 4;
 
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
@@ -4489,7 +4515,17 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const {
     return it->second;
 }
 
-ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const {
+float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
+}
+
+float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const {
+    return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+}
+
+ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const {
+    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
+
     // choose long/short freq factors based on the context size
     if (layers[il].rope_freqs != nullptr) {
         return layers[il].rope_freqs;
@@ -4517,21 +4553,174 @@ struct llm_build_llama : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network (non-MoE)
+            if (model.layers[il].ffn_gate_inp == nullptr) {
+
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il);
+                cb(cur, "ffn_moe_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_llama_iswa : public llm_graph_context {
+    llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
         // temperature tuning
         ggml_tensor * inp_attn_scale = nullptr;
-        if (arch == LLM_ARCH_LLAMA4) {
-            inp_attn_scale = build_inp_attn_scale();
-        }
+        inp_attn_scale = build_inp_attn_scale();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            bool use_rope = arch == LLM_ARCH_LLAMA4
-                ? (il + 1) % hparams.n_no_rope_layer_step != 0
-                : true;
+            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
 
             // norm
             cur = build_norm(inpL,
@@ -4542,7 +4731,7 @@ struct llm_build_llama : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -4590,7 +4779,7 @@ struct llm_build_llama : public llm_graph_context {
                 cb(Kcur, "Kcur", il);
                 cb(Vcur, "Vcur", il);
 
-                if (arch == LLM_ARCH_LLAMA4 && use_rope && hparams.use_kq_norm) {
+                if (use_rope && hparams.use_kq_norm) {
                     // Llama4TextL2Norm
                     Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps);
                     Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps);
@@ -4614,23 +4803,7 @@ struct llm_build_llama : public llm_graph_context {
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
-            // feed-forward network (non-MoE)
-            if (model.layers[il].ffn_gate_inp == nullptr) {
-
-                cur = build_norm(ffn_inp,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = build_ffn(cur,
-                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
-                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
-                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
-                        NULL,
-                        LLM_FFN_SILU, LLM_FFN_PAR, il);
-                cb(cur, "ffn_out", il);
-
-            } else if (arch == LLM_ARCH_LLAMA4) {
+            {
                 // llama4 MoE
                 ggml_tensor * ffn_inp_normed = build_norm(ffn_inp,
                         model.layers[il].ffn_norm, NULL,
@@ -4660,26 +4833,6 @@ struct llm_build_llama : public llm_graph_context {
 
                 cur = ggml_add(ctx0, moe_out, shexp_out);
                 cb(cur, "ffn_moe_out_merged", il);
-
-            } else {
-                // MoE branch
-                cur = build_norm(ffn_inp,
-                        model.layers[il].ffn_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "ffn_norm", il);
-
-                cur = build_moe_ffn(cur,
-                        model.layers[il].ffn_gate_inp,
-                        model.layers[il].ffn_up_exps,
-                        model.layers[il].ffn_gate_exps,
-                        model.layers[il].ffn_down_exps,
-                        nullptr,
-                        n_expert, n_expert_used,
-                        LLM_FFN_SILU, true,
-                        false, 0.0,
-                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
-                        il);
-                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -4753,7 +4906,7 @@ struct llm_build_deci : public llm_graph_context {
             } else if (n_head > 0) {
                 // self-attention
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -7202,8 +7355,8 @@ struct llm_build_phi2 : public llm_graph_context {
     }
 };
 
-struct llm_build_phi3 : public llm_graph_context {
-    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_phi3_iswa : public llm_graph_context {
+    llm_build_phi3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
 
@@ -7217,7 +7370,7 @@ struct llm_build_phi3 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
@@ -7225,7 +7378,7 @@ struct llm_build_phi3 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 ggml_tensor* attn_norm_output = build_norm(inpL,
                         model.layers[il].attn_norm,
@@ -7977,7 +8130,7 @@ struct llm_build_minicpm3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
 
-            ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+            ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
             // norm
             cur = build_norm(inpL,
@@ -8277,8 +8430,8 @@ struct llm_build_gemma : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma2 : public llm_graph_context {
-    llm_build_gemma2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma2_iswa : public llm_graph_context {
+    llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8292,7 +8445,7 @@ struct llm_build_gemma2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             // norm
@@ -8414,8 +8567,8 @@ struct llm_build_gemma2 : public llm_graph_context {
     }
 };
 
-struct llm_build_gemma3 : public llm_graph_context {
-    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_gemma3_iswa : public llm_graph_context {
+    llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_k;
 
         ggml_tensor * cur;
@@ -8433,13 +8586,11 @@ struct llm_build_gemma3 : public llm_graph_context {
         ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_swa = hparams.is_swa(il);
-
-            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
-            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -9016,8 +9167,8 @@ struct llm_build_command_r : public llm_graph_context {
     }
 };
 
-struct llm_build_cohere2 : public llm_graph_context {
-    llm_build_cohere2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+struct llm_build_cohere2_iswa : public llm_graph_context {
+    llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
 
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -9032,7 +9183,7 @@ struct llm_build_cohere2 : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified();
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
 
         for (int il = 0; il < n_layer; ++il) {
             const bool is_swa = hparams.is_swa(il);
@@ -9045,7 +9196,7 @@ struct llm_build_cohere2 : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for 128k context
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -9983,7 +10134,7 @@ struct llm_build_deepseek : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -11347,7 +11498,7 @@ struct llm_build_exaone : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -12263,7 +12414,7 @@ struct llm_build_granite : public llm_graph_context {
                 Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
 
                 if (use_rope) {
-                    ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                    ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
                     Qcur = ggml_rope_ext(
                             ctx0, Qcur, inp_pos, rope_factors,
                             n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
@@ -12916,7 +13067,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
             // self-attention
             {
                 // rope freq factors for llama3; may return nullptr for llama2 and other models
-                ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il);
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
 
                 // compute Q and K and RoPE them
                 ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
@@ -13068,14 +13219,31 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
 
                 LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
 
-                res = new llama_kv_cache_unified(
-                        *this,
-                        params.type_k,
-                        params.type_v,
-                        !cparams.flash_attn,
-                        cparams.offload_kqv,
-                        cparams.n_ctx,
-                        padding);
+                if (hparams.n_swa > 0) {
+                    res = new llama_kv_cache_unified_iswa(
+                            *this,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            params.swa_full,
+                            cparams.n_seq_max,
+                            cparams.n_batch,
+                            padding);
+                } else {
+                    res = new llama_kv_cache_unified(
+                            *this,
+                            nullptr,
+                            params.type_k,
+                            params.type_v,
+                            !cparams.flash_attn,
+                            cparams.offload_kqv,
+                            cparams.n_ctx,
+                            padding,
+                            hparams.n_swa,
+                            hparams.swa_type);
+                }
             }
     }
 
@@ -13090,11 +13258,14 @@ llm_graph_result_ptr llama_model::build_graph(
 
     switch (arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_LLAMA4:
         case LLM_ARCH_MINICPM:
             {
                 llm = std::make_unique<llm_build_llama>(*this, params, gf);
             } break;
+        case LLM_ARCH_LLAMA4:
+            {
+                llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
+            } break;
         case LLM_ARCH_DECI:
             {
                 llm = std::make_unique<llm_build_deci>(*this, params, gf);
@@ -13169,7 +13340,7 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
             {
-                llm = std::make_unique<llm_build_phi3>(*this, params, gf);
+                llm = std::make_unique<llm_build_phi3_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_PLAMO:
             {
@@ -13201,11 +13372,11 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_GEMMA2:
             {
-                llm = std::make_unique<llm_build_gemma2>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma2_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
+                llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_STARCODER2:
             {
@@ -13225,7 +13396,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_COHERE2:
             {
-                llm = std::make_unique<llm_build_cohere2>(*this, params, gf);
+                llm = std::make_unique<llm_build_cohere2_iswa>(*this, params, gf);
             } break;
         case LLM_ARCH_DBRX:
             {

+ 4 - 1
src/llama-model.h

@@ -398,7 +398,10 @@ struct llama_model {
 
     const struct ggml_tensor * get_tensor(const char * name) const;
 
-    ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const;
+    float get_rope_freq_base (const llama_cparams & cparams, int il) const;
+    float get_rope_freq_scale(const llama_cparams & cparams, int il) const;
+
+    ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const;
 
     // note: can mutate `cparams`
     // TODO: move this to new llm_arch_model_i interface

+ 1 - 0
tools/llama-bench/llama-bench.cpp

@@ -991,6 +991,7 @@ struct cmd_params_instance {
         cparams.flash_attn   = flash_attn;
         cparams.embeddings   = embeddings;
         cparams.op_offload   = !no_op_offload;
+        cparams.swa_full     = false;
 
         return cparams;
     }

+ 26 - 1
tools/server/server.cpp

@@ -2004,6 +2004,23 @@ struct server_context {
             }
         }
 
+        if (!llama_kv_self_can_shift(ctx)) {
+            if (params_base.ctx_shift) {
+                params_base.ctx_shift = false;
+                SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled");
+            }
+
+            if (params_base.n_cache_reuse) {
+                params_base.n_cache_reuse = 0;
+                SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
+            }
+
+            if (!params_base.speculative.model.path.empty()) {
+                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
+                return false;
+            }
+        }
+
         return true;
     }
 
@@ -3181,7 +3198,15 @@ struct server_context {
                                 // if we don't cache the prompt, we have to remove the entire KV cache
                                 llama_kv_self_seq_rm(ctx, slot.id, 0, -1);
                                 slot.n_past = 0;
-                                slot.cache_tokens.clear();
+                                slot.cache_tokens.clear(); // TODO: not needed, will be cleared later via "keep_first()"
+                            }
+
+                            if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
+                                if (llama_kv_self_seq_pos_min(ctx, slot.id) > 0) {
+                                    SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
+                                            "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
+                                    slot.n_past = 0;
+                                }
                             }
                         }
 

Некоторые файлы не были показаны из-за большого количества измененных файлов