Jelajahi Sumber

completion : fix prompt cache for recurrent models (#19045)

Georgi Gerganov 5 hari lalu
induk
melakukan
080b161995
2 mengubah file dengan 47 tambahan dan 40 penghapusan
  1. 1 0
      src/llama-context.cpp
  2. 46 40
      tools/completion/completion.cpp

+ 1 - 0
src/llama-context.cpp

@@ -2559,6 +2559,7 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
+    // [TAG_CONTEXT_STATE_LOGITS]
     // write logits
     {
         LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__);

+ 46 - 40
tools/completion/completion.cpp

@@ -342,44 +342,51 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    // debug message about similarity of saved session, if applicable
-    size_t n_matching_session_tokens = 0;
-    if (!session_tokens.empty()) {
-        for (llama_token id : session_tokens) {
-            if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) {
-                break;
-            }
-            n_matching_session_tokens++;
-        }
-        if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
-            LOG_INF("%s: using full prompt from session file\n", __func__);
-        } else if (n_matching_session_tokens >= embd_inp.size()) {
-            LOG_INF("%s: session file has exact match for prompt!\n", __func__);
-        } else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
-            LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
-                    __func__, n_matching_session_tokens, embd_inp.size());
-        } else {
-            LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
-                    __func__, n_matching_session_tokens, embd_inp.size());
-        }
+    bool session_do_save = false;
 
-        // remove any "future" tokens that we might have inherited from the previous session
-        if (!llama_memory_seq_rm(mem, -1, n_matching_session_tokens, -1)) {
-            LOG_INF("%s: unable to resuse common prefix\n", __func__);
-            n_matching_session_tokens = 0;
-            llama_memory_seq_rm(mem, -1, -1, -1);
-        }
-    }
+    {
+        size_t n_match = 0;
+
+        if (!session_tokens.empty()) {
+            for (llama_token id : session_tokens) {
+                if (n_match >= embd_inp.size() || id != embd_inp[n_match]) {
+                    break;
+                }
+                n_match++;
+            }
+            if (params.prompt.empty() && n_match == embd_inp.size()) {
+                LOG_INF("%s: using full prompt from session file\n", __func__);
+            } else if (n_match >= embd_inp.size()) {
+                LOG_INF("%s: session file has exact match for prompt!\n", __func__);
+            } else if (n_match < (embd_inp.size() / 2)) {
+                LOG_WRN("%s: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",
+                        __func__, n_match, embd_inp.size());
+            } else {
+                LOG_INF("%s: session file matches %zu / %zu tokens of prompt\n",
+                        __func__, n_match, embd_inp.size());
+            }
 
-    LOG_DBG("recalculate the cached logits (check): embd_inp.size() %zu, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu\n",
-         embd_inp.size(), n_matching_session_tokens, embd_inp.size(), session_tokens.size());
+            if (session_tokens.size() == n_match) {
+                // [TAG_CONTEXT_STATE_LOGITS]
+                // in this case, we are going to reuse the logits from the session
+                // if we ever decide to remove the logits from the session, we need to handle this somehow
+                // ref: https://github.com/ggml-org/llama.cpp/pull/18862#issuecomment-3756330941
+            }
 
-    // if we will use the cache for the full prompt without reaching the end of the cache, force
-    // reevaluation of the last token to recalculate the cached logits
-    if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) {
-        LOG_DBG("recalculate the cached logits (do): session_tokens.resize( %zu )\n", embd_inp.size() - 1);
+            // remove any "future" tokens that we might have inherited from the previous session
+            if (session_tokens.size() > n_match) {
+                if (!llama_memory_seq_rm(mem, -1, n_match, -1)) {
+                    LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__);
+                    llama_memory_clear(mem, true);
+                    session_tokens.clear();
+                    n_match = 0;
+                } else {
+                    session_tokens.resize(n_match);
+                }
+            }
+        }
 
-        session_tokens.resize(embd_inp.size() - 1);
+        session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro;
     }
 
     // number of tokens to keep when resetting context
@@ -521,10 +528,9 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
 
-    bool is_antiprompt        = false;
-    bool input_echo           = true;
-    bool display              = true;
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
+    bool is_antiprompt = false;
+    bool input_echo    = true;
+    bool display       = true;
 
     int n_past             = 0;
     int n_remain           = params.n_predict;
@@ -700,8 +706,8 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // optionally save the session on first sample (for faster prompt loading next time)
-            if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
-                need_to_save_session = false;
+            if (session_do_save) {
+                session_do_save = false;
                 llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
 
                 LOG_DBG("saved session to %s\n", path_session.c_str());