Browse Source

llama : deprecate explicit kv_self defrag/update calls (#13921)

ggml-ci
Georgi Gerganov 7 months ago
parent
commit
803f8baf4f
3 changed files with 8 additions and 14 deletions
  1. 2 7
      examples/passkey/passkey.cpp
  2. 4 7
      include/llama.h
  3. 2 0
      src/llama-context.cpp

+ 2 - 7
examples/passkey/passkey.cpp

@@ -133,9 +133,8 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_self_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
-            llama_kv_self_update  (ctx);
+            llama_kv_self_seq_add(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
 
             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }
@@ -169,8 +168,6 @@ int main(int argc, char ** argv) {
 
         llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-      //llama_kv_self_defrag (ctx);
-        llama_kv_self_update (ctx);
 
         n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
 
@@ -200,8 +197,6 @@ int main(int argc, char ** argv) {
 
             llama_kv_self_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
             llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
-          //llama_kv_self_defrag (ctx);
-            llama_kv_self_update (ctx);
 
             n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
         }

+ 4 - 7
include/llama.h

@@ -655,7 +655,6 @@ extern "C" {
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_self_seq_add(
@@ -668,7 +667,6 @@ extern "C" {
     // Integer division of the positions by factor of `d > 1`
     // If the KV cache is RoPEd, the KV data is updated accordingly:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_self_seq_div(
@@ -696,16 +694,15 @@ extern "C" {
     // Defragment the KV cache
     // This will be applied:
     //   - lazily on next llama_decode()
-    //   - explicitly with llama_kv_self_update()
-    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
-    LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
+            "simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");
 
     // Check if the context supports KV cache shifting
     LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);
 
     // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
-    // TODO: deprecate and always update the cache lazily [TAG: API_KV_NO_DEFRAG]
-    LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
+    LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
+            "simply remove this call, updates are applied lazily on the next llama_decode()");
 
     //
     // State / sessions

+ 2 - 0
src/llama-context.cpp

@@ -2281,6 +2281,7 @@ llama_kv_cache * llama_get_kv_self(llama_context * ctx) {
     return ctx->get_kv_self();
 }
 
+// deprecated
 void llama_kv_self_update(llama_context * ctx) {
     ctx->kv_self_update();
 }
@@ -2535,6 +2536,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
     return kv->seq_pos_max(seq_id);
 }
 
+// deprecated
 void llama_kv_self_defrag(llama_context * ctx) {
     auto * kv = ctx->get_kv_self();
     if (!kv) {