Browse Source

memory : correctly handle failure in apply() (#14438)

ggml-ci
Georgi Gerganov 6 tháng trước cách đây
mục cha
commit
745f11fed0

+ 1 - 1
src/llama-kv-cache-unified-iswa.cpp

@@ -246,7 +246,7 @@ bool llama_kv_cache_unified_iswa_context::next() {
 }
 
 bool llama_kv_cache_unified_iswa_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 

+ 1 - 1
src/llama-kv-cache-unified.cpp

@@ -1776,7 +1776,7 @@ bool llama_kv_cache_unified_context::next() {
 }
 
 bool llama_kv_cache_unified_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {

+ 1 - 1
src/llama-memory-hybrid.cpp

@@ -218,7 +218,7 @@ bool llama_memory_hybrid_context::next() {
 }
 
 bool llama_memory_hybrid_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 

+ 9 - 1
src/llama-memory-recurrent.cpp

@@ -1071,7 +1071,15 @@ bool llama_memory_recurrent_context::next() {
 }
 
 bool llama_memory_recurrent_context::apply() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is an update
+    if (ubatches.empty()) {
+        // recurrent cache never performs updates
+        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
+
+        return true;
+    }
 
     mem->find_slot(ubatches[i_next]);
 

+ 17 - 0
src/llama-memory.cpp

@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_me
     // if either status has an update, then the combined status has an update
     return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
 }
+
+bool llama_memory_status_is_fail(llama_memory_status status) {
+    switch (status) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                return false;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return true;
+            }
+    }
+
+    return false;
+}

+ 3 - 0
src/llama-memory.h

@@ -31,6 +31,9 @@ enum llama_memory_status {
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);
 
+// helper function for checking if a memory status indicates a failure
+bool llama_memory_status_is_fail(llama_memory_status status);
+
 // the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
 //   - llama_kv_cache_unified_context