Преглед изворни кода

kv-cache : fix seq_rm with seq_id == -1 (#15226)

* kv-cache : fix seq_rm with seq_id == -1

ggml-ci

* cont : iterate over streams

ggml-ci
Georgi Gerganov пре 5 месеци
родитељ
комит
228f724d9c
1 измењених фајлова са 30 додато и 18 уклоњено
  1. 30 18
      src/llama-kv-cache-unified.cpp

+ 30 - 18
src/llama-kv-cache-unified.cpp

@@ -223,12 +223,7 @@ void llama_kv_cache_unified::clear(bool data) {
 }
 }
 
 
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
 bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size());
-
-    auto & cells = v_cells[seq_to_stream[seq_id]];
-    auto & head  = v_heads[seq_to_stream[seq_id]];
-
-    uint32_t new_head = cells.size();
+    GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()));
 
 
     if (p0 < 0) {
     if (p0 < 0) {
         p0 = 0;
         p0 = 0;
@@ -239,6 +234,11 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
     }
     }
 
 
     if (seq_id >= 0) {
     if (seq_id >= 0) {
+        auto & cells = v_cells[seq_to_stream[seq_id]];
+        auto & head  = v_heads[seq_to_stream[seq_id]];
+
+        uint32_t new_head = cells.size();
+
         for (uint32_t i = 0; i < cells.size(); ++i) {
         for (uint32_t i = 0; i < cells.size(); ++i) {
             if (!cells.pos_in(i, p0, p1)) {
             if (!cells.pos_in(i, p0, p1)) {
                 continue;
                 continue;
@@ -250,24 +250,36 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
                 }
                 }
             }
             }
         }
         }
+
+        // If we freed up a slot, set head to it so searching can start there.
+        if (new_head != cells.size() && new_head < head) {
+            head = new_head;
+        }
     } else {
     } else {
         // match any sequence
         // match any sequence
-        for (uint32_t i = 0; i < cells.size(); ++i) {
-            if (!cells.pos_in(i, p0, p1)) {
-                continue;
-            }
+        for (uint32_t s = 0; s < n_stream; ++s) {
+            auto & cells = v_cells[s];
+            auto & head  = v_heads[s];
 
 
-            cells.rm(i);
+            uint32_t new_head = cells.size();
 
 
-            if (new_head == cells.size()) {
-                new_head = i;
+            for (uint32_t i = 0; i < cells.size(); ++i) {
+                if (!cells.pos_in(i, p0, p1)) {
+                    continue;
+                }
+
+                cells.rm(i);
+
+                if (new_head == cells.size()) {
+                    new_head = i;
+                }
             }
             }
-        }
-    }
 
 
-    // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cells.size() && new_head < head) {
-        head = new_head;
+            // If we freed up a slot, set head to it so searching can start there.
+            if (new_head != cells.size() && new_head < head) {
+                head = new_head;
+            }
+        }
     }
     }
 
 
     return true;
     return true;