Просмотр исходного кода

main : option to disable context shift (#9484)

* added cli arg to disable context shift

* reverted precommit

* updated README.md for main

* white space

* allow disabling context shift in the server

* Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* added server example to --no-context-shift args

* removed server changes

* white space

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Vinesh Janarthanan 1 год назад
Родитель
Commit
441b72b91f
4 измененных файлов с 30 добавлено и 15 удалено
  1. 7 1
      common/arg.cpp
  2. 1 0
      common/common.h
  3. 2 0
      examples/main/README.md
  4. 20 14
      examples/main/main.cpp

+ 7 - 1
common/arg.cpp

@@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "enabled" : "disabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
 
     return ctx_arg;
 }
-

+ 1 - 0
common/common.h

@@ -246,6 +246,7 @@ struct gpt_params {
     bool cont_batching     = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn        = false; // flash attention
     bool no_perf           = false; // disable performance metrics
+    bool ctx_shift         = true;  // context shift on infinite text generation
 
     bool input_prefix_bos  = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all        = false; // return logits for all tokens in the batch

+ 2 - 0
examples/main/README.md

@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
 
 If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
 
+The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
+
 It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
 
 ### Temperature

+ 20 - 14
examples/main/main.cpp

@@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
                 // if we run out of context:
                 // - take the n_keep first tokens from the original prompt (via n_past)
                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+
                 if (n_past + (int) embd.size() >= n_ctx) {
-                    if (params.n_predict == -2) {
-                        LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                    if (!params.ctx_shift){
+                        LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                         break;
-                    }
+                    } else {
+                        if (params.n_predict == -2) {
+                            LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                            break;
+                        }
 
-                    const int n_left    = n_past - params.n_keep;
-                    const int n_discard = n_left/2;
+                        const int n_left    = n_past - params.n_keep;
+                        const int n_discard = n_left/2;
 
-                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                            n_past, n_left, n_ctx, params.n_keep, n_discard);
+                        LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                                n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                        llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                        llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                    n_past -= n_discard;
+                        n_past -= n_discard;
 
-                    LOG_DBG("after swap: n_past = %d\n", n_past);
+                        LOG_DBG("after swap: n_past = %d\n", n_past);
 
-                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+                        LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
 
-                    LOG_DBG("clear session path\n");
-                    path_session.clear();
+                        LOG_DBG("clear session path\n");
+                        path_session.clear();
+                    }
                 }
             } else {
                 // context extension via Self-Extend