Quellcode durchsuchen

kv-cache: Fix state restore fragmented cache (#17982)

* kv-cache : fix state restore with fragmented cache (#17527)

Change find_slot to allow non-contiguous allocation during state restore. Fixes 'failed to find available cells in kv cache' error when restoring state to fragmented cache.

* tests : update logic

* cleanup: tightened state_read_meta sig, added is_contiguous case

* fix: state_read_meta arg reorder loose ends

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
ssweens vor 1 Monat
Ursprung
Commit
4529c660c8
4 geänderte Dateien mit 213 neuen und 29 gelöschten Zeilen
  1. 64 27
      src/llama-kv-cache.cpp
  2. 19 2
      src/llama-kv-cache.h
  3. 8 0
      tests/CMakeLists.txt
  4. 122 0
      tests/test-state-restore-fragmented.cpp

+ 64 - 27
src/llama-kv-cache.cpp

@@ -1561,9 +1561,11 @@ void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama
 
 
         const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
         const uint32_t strm = seq_id == -1 ? s : seq_to_stream[seq_id];
 
 
+        slot_info sinfo;
+
         bool res = true;
         bool res = true;
-        res = res && state_read_meta(io, strm, cell_count, seq_id);
-        res = res && state_read_data(io, strm, cell_count);
+        res = res && state_read_meta(io, strm, cell_count, sinfo, seq_id);
+        res = res && state_read_data(io, strm, cell_count, sinfo);
 
 
         if (!res) {
         if (!res) {
             if (seq_id == -1) {
             if (seq_id == -1) {
@@ -1702,7 +1704,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t
     }
     }
 }
 }
 
 
-bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) {
+bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, slot_info & sinfo, llama_seq_id dest_seq_id) {
     auto & cells = v_cells[strm];
     auto & cells = v_cells[strm];
     auto & head  = v_heads[strm];
     auto & head  = v_heads[strm];
 
 
@@ -1739,7 +1741,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             ubatch.seq_id[i]   = &dest_seq_id;
             ubatch.seq_id[i]   = &dest_seq_id;
         }
         }
 
 
-        const auto sinfo = find_slot(ubatch, true);
+        sinfo = find_slot(ubatch, false);
         if (sinfo.empty()) {
         if (sinfo.empty()) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
             return false;
@@ -1749,20 +1751,16 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
         //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         //       see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350
         apply_ubatch(sinfo, ubatch);
         apply_ubatch(sinfo, ubatch);
 
 
-        const auto head_cur = sinfo.head();
-
-        // keep the head at the old position because we will read the KV data into it in state_read_data()
-        head = head_cur;
-
-        LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id);
+        LLAMA_LOG_DEBUG("%s: cell_count = %d, dest_seq_id = %d\n", __func__, cell_count, dest_seq_id);
 
 
-        // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
-        // Assume that this is one contiguous block of cells
-        GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
-        GGML_ASSERT(cells.seq_has(head_cur,                  dest_seq_id));
-        GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
+        // DEBUG CHECK: verify that all cells were allocated and have correct seq_id and pos values
+        GGML_ASSERT(sinfo.n_stream() == 1);
+        GGML_ASSERT(sinfo.idxs[0].size() == cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            const uint32_t idx = sinfo.idxs[0][i];
+            GGML_ASSERT(cells.pos_get(idx) == ubatch.pos[i]);
+            GGML_ASSERT(cells.seq_has(idx, dest_seq_id));
+        }
     } else {
     } else {
         // whole KV cache restore
         // whole KV cache restore
 
 
@@ -1795,15 +1793,24 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
             }
             }
         }
         }
 
 
+        // Create contiguous slot_info for whole cache restore
+        sinfo.s0 = strm;
+        sinfo.s1 = strm;
+        sinfo.resize(1);
+        sinfo.strm[0] = strm;
+        sinfo.idxs[0].resize(cell_count);
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            sinfo.idxs[0][i] = i;
+        }
+
         head = 0;
         head = 0;
     }
     }
 
 
     return true;
     return true;
 }
 }
 
 
-bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) {
+bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo) {
     auto & cells = v_cells[strm];
     auto & cells = v_cells[strm];
-    auto & head  = v_heads[strm];
 
 
     uint32_t v_trans;
     uint32_t v_trans;
     uint32_t n_layer;
     uint32_t n_layer;
@@ -1853,8 +1860,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
         }
         }
 
 
         if (cell_count) {
         if (cell_count) {
-            // Read and set the keys for the whole cell range
-            ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row);
+            if (sinfo.is_contiguous()) {
+                // Fast path: contiguous cells, single memcpy
+                ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row);
+            } else {
+                // Slow path: scatter to non-contiguous positions
+                const void * src = io.read(cell_count * k_size_row);
+                for (uint32_t i = 0; i < cell_count; ++i) {
+                    const size_t dst_offset = sinfo.idxs[0][i] * k_size_row;
+                    ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row);
+                }
+            }
         }
         }
     }
     }
 
 
@@ -1885,8 +1901,17 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }
             }
 
 
             if (cell_count) {
             if (cell_count) {
-                // Read and set the values for the whole cell range
-                ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells, single memcpy
+                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row);
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    const void * src = io.read(cell_count * v_size_row);
+                    for (uint32_t i = 0; i < cell_count; ++i) {
+                        const size_t dst_offset = sinfo.idxs[0][i] * v_size_row;
+                        ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row);
+                    }
+                }
             }
             }
         }
         }
     } else {
     } else {
@@ -1925,10 +1950,22 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32
             }
             }
 
 
             if (cell_count) {
             if (cell_count) {
-                // For each row in the transposed matrix, read the values for the whole cell range
-                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                    const size_t dst_offset = (head + j * cells.size()) * v_size_el;
-                    ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                if (sinfo.is_contiguous()) {
+                    // Fast path: contiguous cells
+                    const uint32_t h = sinfo.head();
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const size_t dst_offset = (h + j * cells.size()) * v_size_el;
+                        ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
+                    }
+                } else {
+                    // Slow path: scatter to non-contiguous positions
+                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                        const void * src = io.read(cell_count * v_size_el);
+                        for (uint32_t i = 0; i < cell_count; ++i) {
+                            const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el;
+                            ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el);
+                        }
+                    }
                 }
                 }
             }
             }
         }
         }

+ 19 - 2
src/llama-kv-cache.h

@@ -72,6 +72,23 @@ public:
         void clear() {
         void clear() {
             idxs.clear();
             idxs.clear();
         }
         }
+
+        // check if indices are contiguous starting from head()
+        bool is_contiguous() const {
+            if (idxs.empty() || idxs[0].empty()) {
+                return true;
+            }
+            if (idxs.size() > 1) {
+                return false;
+            }
+            const uint32_t h = idxs[0][0];
+            for (size_t i = 0; i < idxs[0].size(); ++i) {
+                if (idxs[0][i] != h + i) {
+                    return false;
+                }
+            }
+            return true;
+        }
     };
     };
 
 
     using slot_info_vec_t = std::vector<slot_info>;
     using slot_info_vec_t = std::vector<slot_info>;
@@ -264,8 +281,8 @@ private:
     void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
     void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
     void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
     void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
 
 
-    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
-    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
+    bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count,       slot_info & sinfo, llama_seq_id dest_seq_id = -1);
+    bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, const slot_info & sinfo);
 };
 };
 
 
 class llama_kv_cache_context : public llama_memory_context_i {
 class llama_kv_cache_context : public llama_memory_context_i {

+ 8 - 0
tests/CMakeLists.txt

@@ -222,6 +222,14 @@ llama_build_and_test(test-backend-ops.cpp)
 llama_build_and_test(test-model-load-cancel.cpp  LABEL "model")
 llama_build_and_test(test-model-load-cancel.cpp  LABEL "model")
 llama_build_and_test(test-autorelease.cpp        LABEL "model")
 llama_build_and_test(test-autorelease.cpp        LABEL "model")
 
 
+# Test for state restore with fragmented KV cache
+# Requires a model, uses same args pattern as test-thread-safety
+if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "s390x")
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf)
+else()
+    llama_build_and_test(test-state-restore-fragmented.cpp LABEL "model" ARGS -hf ggml-org/models -hff tinyllamas/stories15M-be.Q4_0.gguf)
+endif()
+
 if (NOT GGML_BACKEND_DL)
 if (NOT GGML_BACKEND_DL)
     # these tests use the backends directly and cannot be built with dynamic loading
     # these tests use the backends directly and cannot be built with dynamic loading
     llama_build_and_test(test-barrier.cpp)
     llama_build_and_test(test-barrier.cpp)

+ 122 - 0
tests/test-state-restore-fragmented.cpp

@@ -0,0 +1,122 @@
+// Test for state restore with fragmented KV cache
+// This tests the fix for: https://github.com/ggml-org/llama.cpp/issues/17527
+// The issue was that state restore required contiguous KV cache slots,
+// which fails when the cache is fragmented.
+//
+// The fix changes find_slot(ubatch, true) to find_slot(ubatch, false)
+// in state_read_meta(), allowing non-contiguous slot allocation.
+
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+
+#include <vector>
+#include <cstdio>
+#include <cstring>
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    params.sampling.seed = 1234;
+    params.kv_unified = true;
+    params.n_parallel = 3;
+    params.n_ctx = 256;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
+        return 1;
+    }
+
+    common_init();
+
+    // init
+    common_init_result_ptr llama_init = common_init_from_params(params);
+
+    llama_model * model = llama_init->model();
+    llama_context * ctx = llama_init->context();
+
+    if (model == nullptr || ctx == nullptr) {
+        fprintf(stderr, "%s : failed to init\n", __func__);
+        return 1;
+    }
+
+    GGML_UNUSED(model);
+
+    // tokenize prompt
+    std::vector<llama_token> tokens(70, 1);
+
+    // interleave the 3 sequences:
+    // 01201230123...
+    llama_batch batch = llama_batch_init(params.n_parallel*tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); i++) {
+        for (int s = 0; s < params.n_parallel; ++s) {
+            common_batch_add(batch, tokens[i], i, {s}, false);
+        }
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch)) {
+        fprintf(stderr, "%s : failed to decode seq 0\n", __func__);
+        return 1;
+    }
+
+    fprintf(stderr, "%s : processed prompt on seq 0, 1, 2 (%zu tokens each)\n", __func__, tokens.size());
+
+    // Save state of seq 1
+    std::vector<uint8_t> seq_state(llama_state_seq_get_size(ctx, 1));
+    const size_t ncopy = llama_state_seq_get_data(ctx, seq_state.data(), seq_state.size(), 1);
+    if (ncopy != seq_state.size()) {
+        fprintf(stderr, "%s : failed to save seq 1 state\n", __func__);
+        return 1;
+    }
+    fprintf(stderr, "%s : saved seq 1 state, %zu bytes\n", __func__, ncopy);
+
+    // clear seq 1 to create a "hole" in the KV cache (fragmentation)
+    // 0.20.20.20.2....
+    llama_memory_t mem = llama_get_memory(ctx);
+    llama_memory_seq_rm(mem, 1, -1, -1);
+    fprintf(stderr, "%s : cleared seq 1 to create fragmentation\n", __func__);
+
+    // Now the cache has holes where seq 1 was
+    // This creates fragmentation - there's no contiguous block large enough
+    // for the seq 1 state if we only look for contiguous slots
+
+    // Restore seq 1 state into seq 1 (should work with non-contiguous allocation)
+    // We use seq 1 since it's a valid sequence ID (0 to n_parallel-1)
+    // Before the fix, this would fail with "failed to find available cells in kv cache"
+    const size_t nset = llama_state_seq_set_data(ctx, seq_state.data(), seq_state.size(), 1);
+    if (nset != seq_state.size()) {
+        fprintf(stderr, "%s : FAILED to restore seq state into fragmented cache (got %zu, expected %zu)\n",
+                __func__, nset, seq_state.size());
+        fprintf(stderr, "%s : This is the bug - state restore fails with fragmented KV cache\n", __func__);
+        llama_batch_free(batch);
+        return 1;
+    }
+    fprintf(stderr, "%s : restored state into seq 1, %zu bytes\n", __func__, nset);
+
+    // Verify we can decode with the restored state
+    // Generate one token to verify the restored state is usable
+    auto sparams = llama_sampler_chain_default_params();
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+    llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sampling.seed));
+
+    auto next_token = llama_sampler_sample(smpl, ctx, -1);
+    auto next_token_str = common_token_to_piece(ctx, next_token);
+
+    common_batch_clear(batch);
+    common_batch_add(batch, next_token, (int)tokens.size(), {1}, true);
+
+    if (llama_decode(ctx, batch)) {
+        fprintf(stderr, "%s : failed to decode with restored state\n", __func__);
+        llama_sampler_free(smpl);
+        llama_batch_free(batch);
+        return 1;
+    }
+
+    fprintf(stderr, "%s : successfully decoded with restored state, generated: '%s'\n", __func__, next_token_str.c_str());
+    fprintf(stderr, "%s : SUCCESS - state restore works with fragmented KV cache\n", __func__);
+
+    llama_sampler_free(smpl);
+    llama_batch_free(batch);
+
+    return 0;
+}