
server: stop generation at `n_ctx_train` if `n_predict` is not set (#6638)

* server: cap n_predict to n_ctx_train if not set

* server: fix infinite loop

* server: infinite loop: move the check into process_token
server: infinite loop: set stopped_limit to true

* minor: spaces

* minor: spaces

* server: include prompt tokens in the EOS limit
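
In short, process_token gains one guard: if the client did not set n_predict and group-attention self-extend is disabled (ga_n == 1), generation stops once prompt plus generated tokens reach the model's training context size. A minimal standalone sketch of that rule, using hypothetical names (slot_state, hit_train_ctx_limit) that stand in for the server's actual slot fields:

    // Sketch of the stopping rule this commit adds (hypothetical names).
    #include <cstdint>
    #include <cstdio>

    struct slot_state {
        int32_t n_predict       = -1; // -1: the client set no limit
        int32_t ga_n            = 1;  // 1: self-context-extend disabled
        int32_t n_prompt_tokens = 0;
        int32_t n_decoded       = 0;  // tokens generated so far
    };

    // True when generation must stop to avoid an EOS-less infinite loop.
    static bool hit_train_ctx_limit(const slot_state & slot, int32_t n_ctx_train) {
        return slot.n_predict < 1 && slot.ga_n == 1 &&
               slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train;
    }

    int main() {
        slot_state slot;
        slot.n_prompt_tokens = 1000;
        slot.n_decoded       = 3096;
        std::printf("stop: %d\n", hit_train_ctx_limit(slot, 4096)); // prints "stop: 1"
        return 0;
    }

In the actual change the same condition also marks the slot as truncated and stopped_limit, so the response reports why generation ended.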
Pierrick Hymbert 1 year ago
parent
commit
7f5ff558ee
1 changed file with 22 additions and 1 deletion

examples/server/server.cpp  +22 -1

@@ -1207,6 +1207,27 @@ struct server_context {
             LOG_VERBOSE("eos token found", {});
             LOG_VERBOSE("eos token found", {});
         }
         }
 
 
+        auto n_ctx_train = llama_n_ctx_train(model);
+        if (slot.params.n_predict < 1 && slot.ga_n == 1
+                    && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+            LOG_WARNING("n_predict is not set and self-context extend is disabled."
+                        " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
+                    { "id_slot",              slot.id },
+                    { "params.n_predict",     slot.params.n_predict },
+                    { "slot.n_prompt_tokens", slot.n_prompt_tokens },
+                    { "slot.n_decoded",       slot.n_decoded },
+                    { "slot.n_predict",       slot.n_predict },
+                    { "n_slots",              params.n_parallel },
+                    { "slot.n_ctx",           slot.n_ctx },
+                    { "n_ctx",                n_ctx },
+                    { "n_ctx_train",          n_ctx_train },
+                    { "ga_n",                 slot.ga_n },
+                });
+            slot.truncated      = true;
+            slot.stopped_limit  = true;
+            slot.has_next_token = false; // stop prediction
+        }
+
         LOG_VERBOSE("next token", {
         LOG_VERBOSE("next token", {
             {"id_slot",        slot.id},
             {"id_slot",        slot.id},
             {"id_task",        slot.id_task},
             {"id_task",        slot.id_task},
@@ -2141,7 +2162,7 @@ struct server_context {
         });
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
 
             for (auto & slot : slots) {
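
The second hunk only drops a redundant cast: batch.n_tokens is already an int32_t, so no conversion is needed in the loop condition. The chunking pattern itself is unchanged; a standalone sketch with made-up sizes:

    // Decode a batch in slices of at most n_batch tokens (illustrative values).
    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int32_t n_tokens_total = 10; // plays the role of batch.n_tokens
        const int32_t n_batch        = 4;  // physical batch size

        for (int32_t i = 0; i < n_tokens_total; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, n_tokens_total - i);
            std::printf("decode tokens [%d, %d)\n", i, i + n_tokens);
        }
        return 0;
    }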