Просмотр исходного кода

llama : fix kv shift bug (#3835)

ggml-ci
Georgi Gerganov 2 лет назад
Родитель
Сommit
71a09da301
1 измененных файлов с 18 добавлено и 9 удалено
  1. 18 9
      llama.cpp

+ 18 - 9
llama.cpp

@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].pos += delta;
+            cache.has_shift = true;
+            cache.cells[i].pos   += delta;
+            cache.cells[i].delta += delta;
+
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
-            } else {
-                cache.has_shift = true;
-                cache.cells[i].delta = delta;
             }
         }
     }
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.has_shift  = false;
-    lctx.kv_self.head      += n_tokens;
-    // Ensure kv cache head points to a valid index.
-    if (lctx.kv_self.head >= lctx.kv_self.size) {
-        lctx.kv_self.head = 0;
+    {
+        if (kv_self.has_shift) {
+            kv_self.has_shift = false;
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].delta = 0;
+            }
+        }
+
+        kv_self.head += n_tokens;
+
+        // Ensure kv cache head points to a valid index.
+        if (kv_self.head >= kv_self.size) {
+            kv_self.head = 0;
+        }
     }
 
 #ifdef GGML_PERF