Selaa lähdekoodia

llama : use n_swa + n_ubatch cells for SWA cache (#13833)

* llama : use n_swa + n_ubatch cells for SWA cache

ggml-ci

* llama : add warning about multi-sqeuence SWA contexts
Georgi Gerganov 7 kuukautta sitten
vanhempi
sitoutus
3600cc2886
6 muutettua tiedostoa jossa 24 lisäystä ja 11 poistoa
  1. 3 0
      include/llama.h
  2. 5 0
      src/llama-context.cpp
  3. 2 2
      src/llama-kv-cache.cpp
  4. 1 1
      src/llama-kv-cache.h
  5. 5 1
      src/llama-model.cpp
  6. 8 7
      tools/server/server.cpp

+ 3 - 0
include/llama.h

@@ -366,6 +366,8 @@ extern "C" {
         bool no_perf;     // measure performance timings
         bool op_offload;  // offload host tensor operations to device
         bool swa_full;    // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
+                          // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
+                          //       ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
     };
 
     // model quantization parameters
@@ -502,6 +504,7 @@ extern "C" {
     LLAMA_API int32_t llama_model_n_layer    (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head     (const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_head_kv  (const struct llama_model * model);
+    LLAMA_API int32_t llama_model_n_swa      (const struct llama_model * model);
 
     // Get the model's RoPE frequency scaling factor
     LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

+ 5 - 0
src/llama-context.cpp

@@ -123,6 +123,11 @@ llama_context::llama_context(
                 __func__, n_ctx_per_seq, hparams.n_ctx_train);
     }
 
+    if (!params.swa_full && cparams.n_seq_max > 1) {
+        LLAMA_LOG_WARN("%s: requested n_seq_max (%u) > 1, but swa_full is not enabled -- performance may be degraded: %s\n",
+                __func__, cparams.n_seq_max, "https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573");
+    }
+
     if (!hparams.vocab_only) {
         // GPU backends
         for (auto * dev : model.devices) {

+ 2 - 2
src/llama-kv-cache.cpp

@@ -1731,14 +1731,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
                      bool   swa_full,
                  uint32_t   kv_size,
                  uint32_t   n_seq_max,
-                 uint32_t   n_batch,
+                 uint32_t   n_ubatch,
                  uint32_t   n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {

+ 1 - 1
src/llama-kv-cache.h

@@ -339,7 +339,7 @@ public:
                          bool   swa_full,
                      uint32_t   kv_size,
                      uint32_t   n_seq_max,
-                     uint32_t   n_batch,
+                     uint32_t   n_ubatch,
                      uint32_t   n_pad);
 
     ~llama_kv_cache_unified_iswa() = default;

+ 5 - 1
src/llama-model.cpp

@@ -13230,7 +13230,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                             params.swa_full,
                             cparams.n_ctx,
                             cparams.n_seq_max,
-                            cparams.n_batch,
+                            cparams.n_ubatch,
                             padding);
                 } else {
                     GGML_ASSERT(!hparams.is_swa_any());
@@ -13593,6 +13593,10 @@ int32_t llama_model_n_head_kv(const llama_model * model) {
     return model->hparams.n_head_kv();
 }
 
+int32_t llama_model_n_swa(const llama_model * model) {
+    return model->hparams.n_swa;
+}
+
 // deprecated
 int32_t llama_n_ctx_train(const llama_model * model) {
     return llama_model_n_ctx_train(model);

+ 8 - 7
tools/server/server.cpp

@@ -2016,11 +2016,6 @@ struct server_context {
                 params_base.n_cache_reuse = 0;
                 SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled");
             }
-
-            if (!params_base.speculative.model.path.empty()) {
-                SRV_ERR("%s\n", "err: speculative decode is not supported by this context");
-                return false;
-            }
         }
 
         return true;
@@ -3215,8 +3210,14 @@ struct server_context {
 
                             if (slot.n_past > 0 && slot.n_past < (int) slot.cache_tokens.size()) {
                                 const auto pos_min = llama_kv_self_seq_pos_min(ctx, slot.id);
-                                if (pos_min > 0) {
-                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                                if (pos_min == -1) {
+                                    SLT_ERR(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min);
+                                    GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237");
+                                }
+
+                                const auto n_swa = llama_model_n_swa(model);
+                                if (pos_min > slot.n_past - n_swa) {
+                                    SLT_WRN(slot, "n_past = %d, cache_tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", slot.n_past, (int) slot.cache_tokens.size(), slot.id, pos_min, n_swa);
                                     SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA, see %s)\n",
                                             "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055");
                                     llama_kv_self_seq_rm(ctx, slot.id, 0, -1);