
Added api for getting/setting the kv_cache (#685)

The API provides methods for retrieving the current memory buffer of the kv_cache and its token count.
It also contains a method for setting the kv_cache from a memory buffer.

This makes it possible to load/save history - maybe support a --cache-prompt parameter as well?
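
As a sketch of the intended use, the new functions are enough to dump the cache to disk after a prompt has been evaluated and to restore it later, skipping re-evaluation. The helpers below are illustrative only (the names, file format, and error handling are not part of this change) and assume a valid struct llama_context * obtained the usual way:

#include <cstdint>
#include <cstdio>
#include <vector>

#include "llama.h"

// Hypothetical helper: write the current KV cache and its token count to a file.
static bool save_kv_cache(struct llama_context * ctx, const char * path) {
    FILE * f = std::fopen(path, "wb");
    if (!f) {
        return false;
    }
    const int    n_tokens = llama_get_kv_cache_token_count(ctx);
    const size_t n_size   = llama_get_kv_cache_size(ctx);
    bool ok = std::fwrite(&n_tokens, sizeof(n_tokens), 1, f) == 1 &&
              std::fwrite(llama_get_kv_cache(ctx), 1, n_size, f) == n_size;
    std::fclose(f);
    return ok;
}

// Hypothetical helper: read a previously saved KV cache back into the context.
static bool load_kv_cache(struct llama_context * ctx, const char * path) {
    FILE * f = std::fopen(path, "rb");
    if (!f) {
        return false;
    }
    int n_tokens = 0;
    std::vector<uint8_t> buf(llama_get_kv_cache_size(ctx));
    bool ok = std::fread(&n_tokens, sizeof(n_tokens), 1, f) == 1 &&
              std::fread(buf.data(), 1, buf.size(), f) == buf.size();
    std::fclose(f);
    if (ok) {
        llama_set_kv_cache(ctx, buf.data(), buf.size(), n_tokens);
    }
    return ok;
}

Restoring only makes sense into a context set up with the same model and parameters: llama_set_kv_cache copies the raw buffer and the token count, and it asserts that the buffer size matches the context's own KV cache size.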

Co-authored-by: Pavol Rusnak <pavol@rusnak.io>
Christian Falch 2 years ago
parent
commit
e986f94829
2 changed files with 44 additions and 0 deletions
  1. llama.cpp (+27, -0)
  2. llama.h (+17, -0)

llama.cpp (+27, -0)

@@ -1668,6 +1668,33 @@ int llama_model_quantize(
     return 0;
 }
 
+// Returns the KV cache that will contain the context for the
+// ongoing prediction with the model.
+const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.data();
+}
+
+// Returns the size of the KV cache
+size_t llama_get_kv_cache_size(struct llama_context * ctx) {
+    return ctx->model.kv_self.buf.size();
+}
+
+int llama_get_kv_cache_token_count(struct llama_context * ctx) {
+    return ctx->model.kv_self.n;
+}
+
+// Sets the KV cache containing the current context for the model
+void llama_set_kv_cache(
+        struct llama_context * ctx,
+               const uint8_t * kv_cache,
+                      size_t   n_size,
+                         int   n_token_count) {
+    // Make sure we have the same kv cache setup
+    LLAMA_ASSERT(ctx->model.kv_self.buf.size() == n_size);
+    memcpy(ctx->model.kv_self.buf.data(), kv_cache, n_size);
+    ctx->model.kv_self.n = n_token_count;
+}
+
 int llama_eval(
         struct llama_context * ctx,
            const llama_token * tokens,

llama.h (+17, -0)

@@ -83,6 +83,23 @@ extern "C" {
             const char * fname_out,
                    int   itype);
 
+    // Returns the KV cache that will contain the context for the
+    // ongoing prediction with the model.
+    LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
+
+    // Returns the size of the KV cache
+    LLAMA_API size_t llama_get_kv_cache_size(struct llama_context * ctx);
+
+    // Returns the number of tokens in the KV cache
+    LLAMA_API int llama_get_kv_cache_token_count(struct llama_context * ctx);
+
+    // Sets the KV cache containing the current context for the model
+    LLAMA_API void llama_set_kv_cache(
+            struct llama_context * ctx,
+                   const uint8_t * kv_cache,
+                          size_t   n_size,
+                             int   n_token_count);
+
     // Run the llama inference to obtain the logits and probabilities for the next token.
     // tokens + n_tokens is the provided batch of new tokens to process
     // n_past is the number of tokens to use from previous eval calls