Просмотр исходного кода

server : add n_indent parameter for line indentation requirement (#9929)

ggml-ci
Georgi Gerganov 1 год назад
Родитель
Сommit
8901755ba3
2 измененных файлов с 49 добавлено и 7 удалено
  1. 2 0
      examples/server/README.md
  2. 47 7
      examples/server/server.cpp

+ 2 - 0
examples/server/README.md

@@ -333,6 +333,8 @@ node index.js
 
     `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. Default: `-1`, where `-1` is infinity.
 
+    `n_indent`: Specify the minimum line indentation for the generated text in number of whitespace characters. Useful for code completion tasks. Default: `0`
+
     `n_keep`: Specify the number of tokens from the prompt to retain when the context size is exceeded and tokens need to be discarded. The number excludes the BOS token.
     By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to retain all tokens from the prompt.
 

+ 47 - 7
examples/server/server.cpp

@@ -131,6 +131,7 @@ struct slot_params {
     int32_t n_keep    =  0; // number of tokens to keep from initial prompt
     int32_t n_discard =  0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
     int32_t n_predict = -1; // new tokens to predict
+    int32_t n_indent  =  0; // mininum line indentation for the generated text in number of whitespace characters
 
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
@@ -173,6 +174,8 @@ struct server_slot {
     std::vector<llama_token> prompt_tokens;
     std::vector<llama_token> extra_tokens;
 
+    size_t last_nl_pos = 0;
+
     std::string generated_text;
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
@@ -215,6 +218,7 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens    = 0;
+        last_nl_pos        = 0;
         generated_text     = "";
         has_new_line       = false;
         truncated          = false;
@@ -860,6 +864,7 @@ struct server_context {
         slot.params.stream             = json_value(data, "stream",            false);
         slot.params.cache_prompt       = json_value(data, "cache_prompt",      false);
         slot.params.n_predict          = json_value(data, "n_predict",         json_value(data, "max_tokens", default_params.n_predict));
+        slot.params.n_indent           = json_value(data, "n_indent",          default_params.n_indent);
         slot.sparams.top_k             = json_value(data, "top_k",             default_sparams.top_k);
         slot.sparams.top_p             = json_value(data, "top_p",             default_sparams.top_p);
         slot.sparams.min_p             = json_value(data, "min_p",             default_sparams.min_p);
@@ -878,7 +883,7 @@ struct server_context {
         slot.sparams.mirostat_tau      = json_value(data, "mirostat_tau",      default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta      = json_value(data, "mirostat_eta",      default_sparams.mirostat_eta);
         slot.sparams.penalize_nl       = json_value(data, "penalize_nl",       default_sparams.penalize_nl);
-        slot.params.n_keep             = json_value(data, "n_keep",            slot.params.n_keep);
+        slot.params.n_keep             = json_value(data, "n_keep",            default_params.n_keep);
         slot.params.n_discard          = json_value(data, "n_discard",         default_params.n_discard);
         slot.sparams.seed              = json_value(data, "seed",              default_sparams.seed);
         slot.sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
@@ -1129,13 +1134,48 @@ struct server_context {
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
         }
 
-        // if we have already seen a new line, we stop after a certain time limit
-        if (slot.has_new_line && slot.params.t_max_predict_ms > 0 &&
-            (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
-            slot.stopped_limit  = true;
-            slot.has_next_token = false;
+        if (slot.has_new_line) {
+            // if we have already seen a new line, we stop after a certain time limit
+            if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+                slot.stopped_limit  = true;
+                slot.has_next_token = false;
+
+                SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+            }
+
+            // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+            if (slot.params.n_indent > 0) {
+                // check the current indentation
+                // TODO: improve by not doing it more than once for each new line
+                if (slot.last_nl_pos > 0) {
+                    size_t pos = slot.last_nl_pos;
+
+                    int n_indent = 0;
+                    while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+                        n_indent++;
+                        pos++;
+                    }
+
+                    if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+                        slot.stopped_limit  = true;
+                        slot.has_next_token = false;
+
+                        // cut the last line
+                        slot.generated_text.erase(pos, std::string::npos);
 
-            SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+                        SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+                    }
+                }
+
+                // find the next new line
+                {
+                    const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+                    if (pos != std::string::npos) {
+                        slot.last_nl_pos = pos + 1;
+                    }
+                }
+            }
         }
 
         // check if there is a new line in the generated text