Просмотр исходного кода

llama : refactor k-shift implementation + KV defragmentation (#5691)

* llama : refactor k-shift implementation

ggml-ci

* llama : rename llama_kv_cache_seq_shift to llama_kv_cache_seq_add

* llama : cont k-shift refactoring + normalize type names

ggml-ci

* minor : fix MPI builds

* llama : reuse n_rot from the build context

ggml-ci

* llama : revert enum name changes from this PR

ggml-ci

* llama : update llama_rope_type

* llama : add comment about rope values

* llama : fix build

* passkey : apply kv cache updates explicitly

ggml-ci

* llama : change name to llama_kv_cache_update()

* llama : add llama_kv_cache_seq_pos_max()

* passkey : fix llama_kv_cache_seq_pos_max() usage

* llama : some llama_kv_cell simplifications

* llama : add llama_kv_cache_compress (EXPERIMENTAL)

* llama : add alternative KV cache merging (EXPERIMENTAL)

* llama : add llama_kv_cache_defrag

* llama : comments

* llama : remove llama_kv_cache_compress

will add in a separate PR

ggml-ci

* llama : defragment via non-overlapping moves

* llama : ggml_graph based defrag implementation

ggml-ci

* llama : switch the loop order in build_defrag

* llama : add comments
Georgi Gerganov 1 год назад
Родитель
Сommit
bf08e00643
6 измененных файлов с 272 добавлено и 243 удалено
  1. 2 2
      examples/infill/infill.cpp
  2. 5 5
      examples/main/main.cpp
  3. 15 10
      examples/passkey/passkey.cpp
  4. 4 4
      examples/server/server.cpp
  5. 215 219
      llama.cpp
  6. 31 3
      llama.h

+ 2 - 2
examples/infill/infill.cpp

@@ -447,8 +447,8 @@ int main(int argc, char ** argv) {
                 LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                     n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-                llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+                llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
 
                 n_past -= n_discard;
 

+ 5 - 5
examples/main/main.cpp

@@ -548,8 +548,8 @@ int main(int argc, char ** argv) {
                     LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
                             n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm   (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
                     n_past -= n_discard;
 
@@ -576,9 +576,9 @@ int main(int argc, char ** argv) {
                     LOG("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
                     LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
 
-                    llama_kv_cache_seq_shift(ctx, 0, ga_i,                n_past,              ib*bd);
-                    llama_kv_cache_seq_div  (ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
-                    llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
+                    llama_kv_cache_seq_add(ctx, 0, ga_i,                n_past,              ib*bd);
+                    llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd,        ga_i + ib*bd + ga_w, ga_n);
+                    llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd,      dd);
 
                     n_past -= bd;
 

+ 15 - 10
examples/passkey/passkey.cpp

@@ -126,7 +126,7 @@ int main(int argc, char ** argv) {
     const int n_batch     = ctx_params.n_batch;
     const int n_batch_grp = ctx_params.n_batch/n_grp;
 
-    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch);
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
 
     // print the prompt token-by-token
 
@@ -146,10 +146,11 @@ int main(int argc, char ** argv) {
             const int ib = i/n_batch - 1;
             const int bd = n_batch_grp*(n_grp - 1);
 
-            llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch,         n_past,         ib*bd);
-            llama_kv_cache_seq_div  (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_seq_add (ctx, 0, n_past - n_batch,         n_past,         ib*bd);
+            llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
+            llama_kv_cache_update  (ctx);
 
-            n_past -= bd;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
 
         llama_batch_clear(batch);
@@ -179,10 +180,12 @@ int main(int argc, char ** argv) {
 
         LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
 
-        llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-        llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+        llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+        llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+        llama_kv_cache_defrag (ctx);
+        llama_kv_cache_update (ctx);
 
-        n_past -= n_discard;
+        n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
 
         llama_batch_clear(batch);
 
@@ -208,10 +211,12 @@ int main(int argc, char ** argv) {
         if (n_discard > 0) {
             LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
-            llama_kv_cache_seq_rm   (ctx, 0, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+            llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx,  -n_discard);
+            llama_kv_cache_defrag (ctx);
+            llama_kv_cache_update (ctx);
 
-            n_past -= n_discard;
+            n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
         }
     }
 

+ 4 - 4
examples/server/server.cpp

@@ -1636,8 +1636,8 @@ struct llama_server_context
                         {"n_system_tokens", system_tokens.size()},
                         {"n_cache_tokens",  slot.cache_tokens.size()}
                     });
-                    llama_kv_cache_seq_rm   (ctx, slot.id, n_keep            , n_keep + n_discard);
-                    llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++)
                     {
@@ -1941,9 +1941,9 @@ struct llama_server_context
                         LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                         LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
                         llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w,slot.ga_n);
-                        llama_kv_cache_seq_shift(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
+                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w,slot.n_past_se + ib * bd, dd);
 
                         slot.n_past_se -= bd;
 

Разница между файлами не показана из-за своего большого размера
+ 215 - 219
llama.cpp


+ 31 - 3
llama.h

@@ -64,6 +64,15 @@ extern "C" {
         LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece
     };
 
+    // note: these values should be synchronized with ggml_rope
+    // TODO: maybe move this enum to ggml.h (ggml_rope_type)
+    enum llama_rope_type {
+        LLAMA_ROPE_TYPE_NONE = -1,
+        LLAMA_ROPE_TYPE_NORM =  0,
+        LLAMA_ROPE_TYPE_NEOX =  2,
+        LLAMA_ROPE_TYPE_GLM  =  4,
+    };
+
     enum llama_token_type {
         LLAMA_TOKEN_TYPE_UNDEFINED    = 0,
         LLAMA_TOKEN_TYPE_NORMAL       = 1,
@@ -360,6 +369,7 @@ extern "C" {
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
 
     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
+    LLAMA_API enum llama_rope_type  llama_rope_type (const struct llama_model * model);
 
     LLAMA_API int32_t llama_n_vocab    (const struct llama_model * model);
     LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
@@ -514,10 +524,12 @@ extern "C" {
                     llama_seq_id   seq_id);
 
     // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
-    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // If the KV cache is RoPEd, the KV data is updated accordingly:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
-    LLAMA_API void llama_kv_cache_seq_shift(
+    LLAMA_API void llama_kv_cache_seq_add(
             struct llama_context * ctx,
                     llama_seq_id   seq_id,
                        llama_pos   p0,
@@ -525,7 +537,9 @@ extern "C" {
                        llama_pos   delta);
 
     // Integer division of the positions by factor of `d > 1`
-    // If the KV cache is RoPEd, the KV data is updated accordingly
+    // If the KV cache is RoPEd, the KV data is updated accordingly:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
     // p0 < 0 : [0,  p1]
     // p1 < 0 : [p0, inf)
     LLAMA_API void llama_kv_cache_seq_div(
@@ -535,6 +549,20 @@ extern "C" {
                        llama_pos   p1,
                              int   d);
 
+    // Returns the largest position present in the KV cache for the specified sequence
+    LLAMA_API llama_pos llama_kv_cache_seq_pos_max(
+            struct llama_context * ctx,
+                    llama_seq_id   seq_id);
+
+    // Defragment the KV cache
+    // This will be applied:
+    //   - lazily on next llama_decode()
+    //   - explicitly with llama_kv_cache_update()
+    LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
+
+    // Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
+    LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
+
     //
     // State / sessions
     //

Некоторые файлы не были показаны из-за большого количества измененных файлов