Pārlūkot izejas kodu

context : fix state io for memory-less contexts (#13470)

ggml-ci
Georgi Gerganov 8 mēneši atpakaļ
vecāks
revīzija
064cc596ac
1 mainītis faili ar 14 papildinājumiem un 7 dzēšanām
  1. 14 7
      src/llama-context.cpp

+ 14 - 7
src/llama-context.cpp

@@ -1788,10 +1788,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
         }
     }
     }
 
 
-    LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
 
 
-    kv_self->state_read(io);
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io);
+    }
 
 
     return io.n_bytes();
     return io.n_bytes();
 }
 }
@@ -1799,9 +1802,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
     GGML_UNUSED(seq_id);
 
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
 
-    kv_self->state_write(io, seq_id);
+        kv_self->state_write(io, seq_id);
+    }
 
 
     return io.n_bytes();
     return io.n_bytes();
 }
 }
@@ -1809,9 +1814,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
     GGML_UNUSED(seq_id);
 
 
-    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
 
-    kv_self->state_read(io, seq_id);
+        kv_self->state_read(io, seq_id);
+    }
 
 
     return io.n_bytes();
     return io.n_bytes();
 }
 }