@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(

     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
-            cache.cells[i].pos += delta;
+            cache.has_shift = true;
+            cache.cells[i].pos   += delta;
+            cache.cells[i].delta += delta;
+
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
-            } else {
-                cache.has_shift = true;
-                cache.cells[i].delta = delta;
             }
         }
     }
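
The hunk above changes llama_kv_cache_seq_shift in two ways: has_shift and the
per-cell delta are now recorded for every shifted cell, not only for cells that
survive the shift, and delta is accumulated with += rather than overwritten
with =, so that consecutive shifts compose before the next decode consumes
them. Below is a minimal standalone sketch of why the accumulation matters;
the types and the seq_shift helper are toy stand-ins for illustration, not
llama.cpp's real structs.

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct kv_cell {
    int32_t pos   = -1;
    int32_t delta = 0; // total shift accumulated since the last decode
    std::set<int32_t> seq_id;

    bool has_seq_id(int32_t id) const { return seq_id.count(id) > 0; }
};

struct kv_cache {
    bool     has_shift = false;
    uint32_t head      = 0; // ring-buffer head, used by the second hunk
    std::vector<kv_cell> cells;
};

// mirrors the patched loop: mark the cache dirty and *accumulate* the delta
static void seq_shift(kv_cache & cache, int32_t seq_id, int32_t p0, int32_t p1, int32_t delta) {
    for (auto & cell : cache.cells) {
        if (cell.has_seq_id(seq_id) && cell.pos >= p0 && cell.pos < p1) {
            cache.has_shift = true;
            cell.pos   += delta;
            cell.delta += delta;

            if (cell.pos < 0) {
                cell.pos = -1;
                cell.seq_id.clear();
            }
        }
    }
}

int main() {
    kv_cache cache;
    cache.cells.resize(1);
    cache.cells[0].pos = 10;
    cache.cells[0].seq_id.insert(0);

    // two shifts land before the next decode consumes them; the total
    // correction that must be applied later is -5, which `delta +=`
    // preserves and the old `delta =` would have collapsed to just -3
    seq_shift(cache, 0, 0, 100, -2);
    seq_shift(cache, 0, 0, 100, -3);

    printf("pos = %d, delta = %d\n", cache.cells[0].pos, cache.cells[0].delta); // pos = 5, delta = -5
    return 0;
}
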
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
 #endif

     // update the kv ring buffer
-    lctx.kv_self.has_shift  = false;
-    lctx.kv_self.head      += n_tokens;
-    // Ensure kv cache head points to a valid index.
-    if (lctx.kv_self.head >= lctx.kv_self.size) {
-        lctx.kv_self.head = 0;
+    {
+        if (kv_self.has_shift) {
+            kv_self.has_shift = false;
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].delta = 0;
+            }
+        }
+
+        kv_self.head += n_tokens;
+
+        // Ensure kv cache head points to a valid index.
+        if (kv_self.head >= kv_self.size) {
+            kv_self.head = 0;
+        }
     }

 #ifdef GGML_PERF
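
The second hunk makes llama_decode_internal the single consumer of a pending
shift: has_shift is cleared and every cell's accumulated delta is zeroed only
when a shift actually occurred, whereas the old code reset has_shift
unconditionally and never cleared the deltas, which the diff suggests could
leave stale values behind. A sketch of the consume step, continuing the toy
types above (consume_shift is a hypothetical name):

// mirrors the patched end of llama_decode_internal: consume a pending
// shift exactly once, then advance and wrap the ring-buffer head
static void consume_shift(kv_cache & cache, uint32_t n_tokens) {
    if (cache.has_shift) {
        cache.has_shift = false;
        for (auto & cell : cache.cells) {
            cell.delta = 0; // the shift has been applied; start accumulating anew
        }
    }

    cache.head += n_tokens;

    // ensure the head points to a valid index
    if (cache.head >= (uint32_t) cache.cells.size()) {
        cache.head = 0;
    }
}
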