فهرست منبع

server : pre-calculate EOG logit biases (#14721)

ggml-ci
Georgi Gerganov 6 ماه پیش
والد
کامیت
6ffd4e9c44
3فایلهای تغییر یافته به همراه17 افزوده شده و 15 حذف شده
  1. 12 6
      common/common.cpp
  2. 2 1
      common/common.h
  3. 3 8
      tools/server/server.cpp

+ 12 - 6
common/common.cpp

@@ -1005,15 +1005,21 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
-    if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-            if (llama_vocab_is_eog(vocab, i)) {
-                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
-                params.sampling.logit_bias.push_back({i, -INFINITY});
-            }
+    // initialize once
+    for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+        if (llama_vocab_is_eog(vocab, i)) {
+            LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+            params.sampling.logit_bias_eog.push_back({i, -INFINITY});
         }
     }
 
+    if (params.sampling.ignore_eos) {
+        // add EOG biases to the active set of logit biases
+        params.sampling.logit_bias.insert(
+                params.sampling.logit_bias.end(),
+                params.sampling.logit_bias_eog.begin(), params.sampling.logit_bias_eog.end());
+    }
+
     if (params.sampling.penalty_last_n == -1) {
         LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
         params.sampling.penalty_last_n = llama_n_ctx(lctx);

+ 2 - 1
common/common.h

@@ -177,7 +177,8 @@ struct common_params_sampling {
     std::vector<common_grammar_trigger> grammar_triggers; // optional triggers (for lazy grammars)
     std::set<llama_token>               preserved_tokens;
 
-    std::vector<llama_logit_bias> logit_bias; // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias;     // logit biases to apply
+    std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
     // print the parameters into a string
     std::string print() const;

+ 3 - 8
tools/server/server.cpp

@@ -473,12 +473,9 @@ struct server_task {
 
             params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos);
             if (params.sampling.ignore_eos) {
-                for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
-                    if (llama_vocab_is_eog(vocab, i)) {
-                        //SRV_DBG("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(ctx, i).c_str(), -INFINITY);
-                        params.sampling.logit_bias.push_back({i, -INFINITY});
-                    }
-                }
+                params.sampling.logit_bias.insert(
+                        params.sampling.logit_bias.end(),
+                        defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end());
             }
         }
 
@@ -1906,7 +1903,6 @@ struct server_context {
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
-    bool has_eos_token  = false;
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -1965,7 +1961,6 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_vocab_get_add_bos(vocab);
-        has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
 
         if (!params_base.speculative.model.path.empty() || !params_base.speculative.model.hf_repo.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str());