Jelajahi Sumber

sampling : refactor + optimize penalties sampler (#10803)

* sampling : refactor + optimize penalties sampler

ggml-ci

* common : apply ignore_eos as logit bias

ggml-ci

* batched : remove penalties sampler

* params : allow penalty_last_n == -1 to be equal to context size

ggml-ci

* common : by default, move the penalties at the end of the sampling chain

ggml-ci

* common : ignore all EOG tokens

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* common : move back the penalties at the front of the sampling chain

ggml-ci

* readme : restore hint about --ignore-eos flag [no ci]

* llama : minor

ggml-ci

* webui : update

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
Georgi Gerganov 1 tahun lalu
induk
melakukan
644fd71b44

+ 6 - 7
common/arg.cpp

@@ -855,13 +855,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.ignore_eos = true;
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"--penalize-nl"},
-        string_format("penalize newline tokens (default: %s)", params.sampling.penalize_nl ? "true" : "false"),
-        [](common_params & params) {
-            params.sampling.penalize_nl = true;
-        }
-    ).set_sparam());
     add_opt(common_arg(
         {"--temp"}, "N",
         string_format("temperature (default: %.1f)", (double)params.sampling.temp),
@@ -916,6 +909,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--repeat-last-n"}, "N",
         string_format("last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", params.sampling.penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid repeat-last-n = %d\n", value));
+            }
             params.sampling.penalty_last_n = value;
             params.sampling.n_prev = std::max(params.sampling.n_prev, params.sampling.penalty_last_n);
         }
@@ -970,6 +966,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--dry-penalty-last-n"}, "N",
         string_format("set DRY penalty for the last n tokens (default: %d, 0 = disable, -1 = context size)", params.sampling.dry_penalty_last_n),
         [](common_params & params, int value) {
+            if (value < -1) {
+                throw std::runtime_error(string_format("error: invalid dry-penalty-last-n = %d\n", value));
+            }
             params.sampling.dry_penalty_last_n = value;
         }
     ).set_sparam());

+ 19 - 0
common/common.cpp

@@ -940,6 +940,25 @@ struct common_init_result common_init_from_params(common_params & params) {
         params.sampling.ignore_eos = false;
     }
 
+    if (params.sampling.ignore_eos) {
+        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
+            if (llama_token_is_eog(model, i)) {
+                LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
+                params.sampling.logit_bias.push_back({i, -INFINITY});
+            }
+        }
+    }
+
+    if (params.sampling.penalty_last_n == -1) {
+        LOG_INF("%s: setting penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.penalty_last_n = llama_n_ctx(lctx);
+    }
+
+    if (params.sampling.dry_penalty_last_n == -1) {
+        LOG_INF("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
     if (params.warmup) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
 

+ 9 - 6
common/common.h

@@ -95,6 +95,7 @@ enum common_sampler_type {
     COMMON_SAMPLER_TYPE_TEMPERATURE = 7,
     COMMON_SAMPLER_TYPE_XTC         = 8,
     COMMON_SAMPLER_TYPE_INFILL      = 9,
+    COMMON_SAMPLER_TYPE_PENALTIES   = 10,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -130,7 +131,6 @@ struct common_params_sampling {
     int32_t mirostat           = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau       = 5.00f; // target entropy
     float   mirostat_eta       = 0.10f; // learning rate
-    bool    penalize_nl        = false; // consider newlines as a repeatable token
     bool    ignore_eos         = false;
     bool    no_perf            = false; // disable performance metrics
     bool    timing_per_token   = false;
@@ -139,6 +139,7 @@ struct common_params_sampling {
 
 
     std::vector<enum common_sampler_type> samplers = {
+        COMMON_SAMPLER_TYPE_PENALTIES,
         COMMON_SAMPLER_TYPE_DRY,
         COMMON_SAMPLER_TYPE_TOP_K,
         COMMON_SAMPLER_TYPE_TYPICAL_P,
@@ -193,11 +194,13 @@ struct common_params {
     float   defrag_thold          =  0.1f; // KV cache defragmentation threshold
 
     // offload params
-    std::vector<ggml_backend_dev_t> devices;         // devices to use for offloading
-    int32_t n_gpu_layers                    =    -1; // number of layers to store in VRAM (-1 - use default)
-    int32_t main_gpu                        =     0; // the GPU that is used for scratch and small tensors
-    float   tensor_split[128]               =   {0}; // how split tensors should be distributed across GPUs
-    enum llama_split_mode        split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
+    std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
+
+    int32_t n_gpu_layers      = -1;  // number of layers to store in VRAM (-1 - use default)
+    int32_t main_gpu          = 0;   // the GPU that is used for scratch and small tensors
+    float   tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
+
+    enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
 
     struct cpu_params cpuparams;
     struct cpu_params cpuparams_batch;

+ 11 - 16
common/sampling.cpp

@@ -161,32 +161,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.logit_bias.size(),
                 params.logit_bias.data()));
 
-    llama_sampler_chain_add(result->chain,
-            llama_sampler_init_penalties(
-                llama_n_vocab  (model),
-                llama_token_eos(model),
-                llama_token_nl (model),
-                params.penalty_last_n,
-                params.penalty_repeat,
-                params.penalty_freq,
-                params.penalty_present,
-                params.penalize_nl,
-                params.ignore_eos));
-
     if (params.mirostat == 0) {
         for (const auto & cnstr : params.samplers) {
             switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_DRY:
+                case COMMON_SAMPLER_TYPE_DRY:
                     {
-                        std::vector<const char*> c_breakers;
+                        std::vector<const char *> c_breakers;
                         c_breakers.reserve(params.dry_sequence_breakers.size());
-                        for (const auto& str : params.dry_sequence_breakers) {
+                        for (const auto & str : params.dry_sequence_breakers) {
                             c_breakers.push_back(str.c_str());
                         }
 
                         llama_sampler_chain_add(result->chain, llama_sampler_init_dry      (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                     }
-                        break;
+                    break;
                 case COMMON_SAMPLER_TYPE_TOP_K:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
                     break;
@@ -208,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 case COMMON_SAMPLER_TYPE_INFILL:
                     llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
                     break;
+                case COMMON_SAMPLER_TYPE_PENALTIES:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
+                    break;
                 default:
                     GGML_ASSERT(false && "unknown sampler type");
             }
@@ -415,6 +406,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
         case COMMON_SAMPLER_TYPE_XTC:         return 'x';
         case COMMON_SAMPLER_TYPE_INFILL:      return 'i';
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return 'e';
         default : return '?';
     }
 }
@@ -429,6 +421,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
         case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
         case COMMON_SAMPLER_TYPE_XTC:         return "xtc";
         case COMMON_SAMPLER_TYPE_INFILL:      return "infill";
+        case COMMON_SAMPLER_TYPE_PENALTIES:   return "penalties";
         default : return "";
     }
 }
@@ -443,6 +436,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
         { "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
         { "xtc",         COMMON_SAMPLER_TYPE_XTC },
         { "infill",      COMMON_SAMPLER_TYPE_INFILL },
+        { "penalties",   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     // since samplers names are written multiple ways
@@ -489,6 +483,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC),         COMMON_SAMPLER_TYPE_XTC },
         { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL),      COMMON_SAMPLER_TYPE_INFILL },
+        { common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_PENALTIES),   COMMON_SAMPLER_TYPE_PENALTIES },
     };
 
     std::vector<common_sampler_type> samplers;

+ 1 - 0
examples/batched/batched.cpp

@@ -65,6 +65,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     auto sparams = llama_sampler_chain_default_params();
+    sparams.no_perf = false;
 
     llama_sampler * smpl = llama_sampler_chain_init(sparams);
 

+ 0 - 5
examples/main/README.md

@@ -177,16 +177,11 @@ Example usage: `--temp 0`
 
 -   `--repeat-penalty N`: Control the repetition of token sequences in the generated text default: 1.0, 1.0 = disabled).
 -   `--repeat-last-n N`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
--   `--no-penalize-nl`: Disable penalization for newline tokens when applying the repeat penalty.
 
 The `repeat-penalty` option helps prevent the model from generating repetitive or monotonous text. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. The default value is 1.
 
 The `repeat-last-n` option controls the number of tokens in the history to consider for penalizing repetition. A larger value will look further back in the generated text to prevent repetitions, while a smaller value will only consider recent tokens. A value of 0 disables the penalty, and a value of -1 sets the number of tokens considered equal to the context size (`ctx-size`).
 
-Use the `--no-penalize-nl` option to disable newline penalization when applying the repeat penalty. This option is particularly useful for generating chat conversations, dialogues, code, poetry, or any text where newline tokens play a significant role in structure and formatting. Disabling newline penalization helps maintain the natural flow and intended formatting in these specific use cases.
-
-Example usage: `--repeat-penalty 1.15 --repeat-last-n 128 --no-penalize-nl`
-
 ### DRY Repetition Penalty
 
 DRY (Don't Repeat Yourself) sampling is an effective technique for reducing repetition in generated text even across long contexts by penalizing tokens based on their recent usage patterns (original [PR link](https://github.com/oobabooga/text-generation-webui/pull/5677)).

+ 0 - 5
examples/server/README.md

@@ -104,7 +104,6 @@ The project is under active development, and we are [looking for feedback and co
 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
 | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
 | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
-| `--penalize-nl` | penalize newline tokens (default: false) |
 | `--temp N` | temperature (default: 0.8) |
 | `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
 | `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
@@ -393,8 +392,6 @@ These words will not be included in the completion, so make sure to add them to
 
 `repeat_last_n`: Last n tokens to consider for penalizing repetition. Default: `64`, where `0` is disabled and `-1` is ctx-size.
 
-`penalize_nl`: Penalize newline tokens when applying the repeat penalty. Default: `true`
-
 `presence_penalty`: Repeat alpha presence penalty. Default: `0.0`, which is disabled.
 
 `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
@@ -655,7 +652,6 @@ This endpoint is public (no API key check). By default, it is read-only. To make
       "mirostat": 0,
       "mirostat_tau": 5.0,
       "mirostat_eta": 0.10000000149011612,
-      "penalize_nl": false,
       "stop": [],
       "max_tokens": -1,
       "n_keep": 0,
@@ -845,7 +841,6 @@ Example:
       "mirostat": 0,
       "mirostat_tau": 5.0,
       "mirostat_eta": 0.10000000149011612,
-      "penalize_nl": false,
       "stop": [],
       "max_tokens": -1,
       "n_keep": 0,

TEMPAT SAMPAH
examples/server/public/index.html.gz


+ 0 - 1
examples/server/public_legacy/index-new.html

@@ -39,7 +39,6 @@
       temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower
       repeat_last_n: 0, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.0, // 1.0 = disabled
-      penalize_nl: false, // true only useful for infinite completion
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75,     // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well

+ 0 - 2
examples/server/public_legacy/index.html

@@ -303,7 +303,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well
       dry_base: 1.75,     // 0.0 = disabled
       dry_allowed_length: 2, // tokens extending repetitions beyond this receive penalty, 2 works well
@@ -1006,7 +1005,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

+ 23 - 5
examples/server/server.cpp

@@ -135,7 +135,6 @@ struct slot_params {
             {"mirostat",                  sampling.mirostat},
             {"mirostat_tau",              sampling.mirostat_tau},
             {"mirostat_eta",              sampling.mirostat_eta},
-            {"penalize_nl",               sampling.penalize_nl},
             {"stop",                      antiprompt},
             {"max_tokens",                n_predict}, // User configured n_predict
             {"n_keep",                    n_keep},
@@ -184,6 +183,7 @@ struct server_task {
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
+            const llama_context * ctx,
             const common_params & params_base,
             const json & data) {
         slot_params params;
@@ -226,7 +226,6 @@ struct server_task {
         params.sampling.mirostat           = json_value(data, "mirostat",           defaults.sampling.mirostat);
         params.sampling.mirostat_tau       = json_value(data, "mirostat_tau",       defaults.sampling.mirostat_tau);
         params.sampling.mirostat_eta       = json_value(data, "mirostat_eta",       defaults.sampling.mirostat_eta);
-        params.sampling.penalize_nl        = json_value(data, "penalize_nl",        defaults.sampling.penalize_nl);
         params.sampling.seed               = json_value(data, "seed",               defaults.sampling.seed);
         params.sampling.n_probs            = json_value(data, "n_probs",            defaults.sampling.n_probs);
         params.sampling.min_keep           = json_value(data, "min_keep",           defaults.sampling.min_keep);
@@ -239,8 +238,27 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // TODO: add more sanity checks for the input parameters
+
+        if (params.sampling.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (params.sampling.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (params.sampling.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            params.sampling.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (params.sampling.dry_penalty_last_n == -1) {
+            params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
+        }
+
         if (params.sampling.dry_base < 1.0f) {
-           params.sampling.dry_base = defaults.sampling.dry_base;
+            params.sampling.dry_base = defaults.sampling.dry_base;
         }
 
         // sequence breakers for DRY
@@ -1469,7 +1487,7 @@ struct server_context {
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_add_bos_token(model);
-        has_eos_token = !llama_add_eos_token(model);
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
 
         if (!params_base.speculative.model.empty()) {
             SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@@ -3381,7 +3399,7 @@ int main(int argc, char ** argv) {
                 task.index = i;
 
                 task.prompt_tokens    = std::move(tokenized_prompts[i]);
-                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.params_base, data);
+                task.params           = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
                 task.id_selected_slot = json_value(data, "id_slot", -1);
 
                 // OAI-compat

+ 0 - 2
examples/server/themes/buttons-top/index.html

@@ -222,7 +222,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
@@ -779,7 +778,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

+ 0 - 2
examples/server/themes/wild/index.html

@@ -225,7 +225,6 @@
       temperature: 0.7,
       repeat_last_n: 256, // 0 = disable penalty, -1 = context size
       repeat_penalty: 1.18, // 1.0 = disabled
-      penalize_nl: false,
       top_k: 40, // <= 0 to use vocab size
       top_p: 0.95, // 1.0 = disabled
       min_p: 0.05, // 0 = disabled
@@ -782,7 +781,6 @@
             ${FloatField({ label: "Temperature", max: 2.0, min: 0.0, name: "temperature", step: 0.01, value: params.value.temperature })}
             ${FloatField({ label: "Penalize repeat sequence", max: 2.0, min: 0.0, name: "repeat_penalty", step: 0.01, value: params.value.repeat_penalty })}
             ${IntField({ label: "Consider N tokens for penalize", max: 2048, min: 0, name: "repeat_last_n", value: params.value.repeat_last_n })}
-            ${BoolField({ label: "Penalize repetition of newlines", name: "penalize_nl", value: params.value.penalize_nl })}
             ${IntField({ label: "Top-K sampling", max: 100, min: -1, name: "top_k", value: params.value.top_k })}
             ${FloatField({ label: "Top-P sampling", max: 1.0, min: 0.0, name: "top_p", step: 0.01, value: params.value.top_p })}
             ${FloatField({ label: "Min-P sampling", max: 1.0, min: 0.0, name: "min_p", step: 0.01, value: params.value.min_p })}

+ 1 - 1
examples/server/webui/src/main.js

@@ -33,7 +33,7 @@ const CONFIG_DEFAULT = {
   systemMessage: 'You are a helpful assistant.',
   showTokensPerSecond: false,
   // make sure these default values are in sync with `common.h`
-  samplers: 'dkypmxt',
+  samplers: 'edkypmxt',
   temperature: 0.8,
   dynatemp_range: 0.0,
   dynatemp_exponent: 1.0,

+ 5 - 9
include/llama.h

@@ -1139,16 +1139,12 @@ extern "C" {
                           const char * grammar_str,
                           const char * grammar_root);
 
+    /// NOTE: Avoid using on the full vocabulary as searching for repeated tokens can become slow. For example, apply top-k or top-p sampling first.
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
-                             int32_t   n_vocab,         // llama_n_vocab()
-                         llama_token   special_eos_id,  // llama_token_eos()
-                         llama_token   linefeed_id,     // llama_token_nl()
-                             int32_t   penalty_last_n,  // last n tokens to penalize (0 = disable penalty, -1 = context size)
-                               float   penalty_repeat,  // 1.0 = disabled
-                               float   penalty_freq,    // 0.0 = disabled
-                               float   penalty_present, // 0.0 = disabled
-                                bool   penalize_nl,     // consider newlines as a repeatable token
-                                bool   ignore_eos);     // ignore the end-of-sequence token
+                             int32_t   penalty_last_n,   // last n tokens to penalize (0 = disable penalty, -1 = context size)
+                               float   penalty_repeat,   // 1.0 = disabled
+                               float   penalty_freq,     // 0.0 = disabled
+                               float   penalty_present); // 0.0 = disabled
 
     ///  @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
     LLAMA_API struct llama_sampler *    llama_sampler_init_dry(

+ 35 - 90
src/llama-sampling.cpp

@@ -1396,19 +1396,15 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab
 // penalties
 
 struct llama_sampler_penalties {
-    const int32_t     n_vocab;
-    const llama_token special_eos_id;
-    const llama_token linefeed_id;
-
     const int32_t penalty_last_n;
     const float   penalty_repeat;
     const float   penalty_freq;
     const float   penalty_present;
 
-    const bool    penalize_nl;
-    const bool    ignore_eos;
-
     ring_buffer<llama_token> prev;
+
+    // a frequency map to count token occurrences
+    std::unordered_map<llama_token, int> token_count;
 };
 
 static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
@@ -1421,76 +1417,50 @@ static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_to
         return;
     }
 
-    ctx->prev.push_back(token);
-}
-
-static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
-    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
+    ctx->token_count[token]++;
 
-    if (ctx->ignore_eos) {
-        assert(ctx->special_eos_id >= 0);
+    // if the ring buffer is full, remove the oldest token
+    if (ctx->prev.size() >= (size_t) ctx->penalty_last_n) {
+        const auto old = ctx->prev.front();
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
-            cur_p->data[ctx->special_eos_id].logit = -INFINITY;
-        } else {
-            // else, search for the special EOS token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->special_eos_id) {
-                    cur_p->data[i].logit = -INFINITY;
-                    break;
-                }
-            }
+        ctx->token_count[old]--;
+        if (ctx->token_count[old] == 0) {
+            ctx->token_count.erase(old);
         }
     }
 
-    if ((ctx->penalty_last_n == 0) ||
-        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
-        return;
-    }
-
-    bool nl_found = false;
-    size_t nl_idx = 0;
-    float nl_logit = -INFINITY;
-    if (!ctx->penalize_nl) {
-        assert(ctx->linefeed_id >= 0);
+    ctx->prev.push_back(token);
 
-        // optimistically check if the candidates are not yet sorted/shuffled/truncated
-        if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
-            nl_found = true;
-            nl_idx = ctx->linefeed_id;
-            nl_logit = cur_p->data[ctx->linefeed_id].logit;
-        } else {
-            // else, search for the linefeed token
-            for (size_t i = 0; i < cur_p->size; ++i) {
-                if (cur_p->data[i].id == ctx->linefeed_id) {
-                    nl_found = true;
-                    nl_idx = i;
-                    nl_logit = cur_p->data[i].logit;
-                    break;
-                }
-            }
-        }
+#if 0
+    // sanity check
+    std::unordered_map<llama_token, int> tmp;
+    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
+        tmp[ctx->prev.rat(i)]++;
     }
 
-    // Create a frequency map to count occurrences of each token in last_tokens
-    // TODO: optimize this by maintaining the token count in the sampler context
-    using llama_token_cnt = std::unordered_map<llama_token, int>;
-    llama_token_cnt token_count;
+    assert(ctx->token_count == tmp);
+#endif
+}
+
+static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_penalties *) smpl->ctx;
 
-    for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
-        token_count[ctx->prev.rat(i)]++;
+    if ((ctx->penalty_last_n == 0) ||
+        (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
+        return;
     }
 
     // Apply frequency and presence penalties to the cur_p
     for (size_t i = 0; i < cur_p->size; ++i) {
-        const auto token_iter = token_count.find(cur_p->data[i].id);
-        if (token_iter == token_count.end()) {
+        const auto token_iter = ctx->token_count.find(cur_p->data[i].id);
+        if (token_iter == ctx->token_count.end()) {
             continue;
         }
 
         const int count = token_iter->second;
 
+        assert(count > 0 && count <= ctx->penalty_last_n);
+
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (cur_p->data[i].logit <= 0) {
@@ -1503,30 +1473,21 @@ static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_tok
     }
 
     cur_p->sorted = false;
-
-    if (!ctx->penalize_nl && nl_found) {
-        // restore the logit of the newline token if it was penalized
-        cur_p->data[nl_idx].logit = nl_logit;
-    }
 }
 
 static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_penalties *) smpl->ctx;
     ctx->prev.clear();
+    ctx->token_count.clear();
 }
 
 static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
     auto * result = llama_sampler_init_penalties(
-            ctx->n_vocab,
-            ctx->special_eos_id,
-            ctx->linefeed_id,
             ctx->penalty_last_n,
             ctx->penalty_repeat,
             ctx->penalty_freq,
-            ctx->penalty_present,
-            ctx->penalize_nl,
-            ctx->ignore_eos);
+            ctx->penalty_present);
 
     // copy the state
     {
@@ -1552,38 +1513,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = {
 };
 
 struct llama_sampler * llama_sampler_init_penalties(
-        int32_t n_vocab,
-        llama_token special_eos_id,
-        llama_token linefeed_id,
         int32_t penalty_last_n,
         float penalty_repeat,
         float penalty_freq,
-        float penalty_present,
-        bool penalize_nl,
-        bool ignore_eos) {
-    if (linefeed_id == LLAMA_TOKEN_NULL) {
-        penalize_nl = true;
-    }
-
-    if (special_eos_id == LLAMA_TOKEN_NULL) {
-        ignore_eos = false;
-    }
-
+        float penalty_present) {
     penalty_last_n = std::max(penalty_last_n, 0);
 
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx   = */ new llama_sampler_penalties {
-            /* .n_vocab         = */ n_vocab,
-            /* .special_eos_id  = */ special_eos_id,
-            /* .linefeed_id     = */ linefeed_id,
             /* .penalty_last_n  = */ penalty_last_n,
             /* .penalty_repeat  = */ penalty_repeat,
             /* .penalty_freq    = */ penalty_freq,
             /* .penalty_present = */ penalty_present,
-            /* .penalize_nl     = */ penalize_nl,
-            /* .ignore_eos      = */ ignore_eos,
             /* .prev            = */ ring_buffer<llama_token>(penalty_last_n),
+            /* .token_count     = */ {},
         },
     };
 }
@@ -1611,7 +1555,8 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
         if (word.find(str) != std::string::npos) {
             token_sequences.emplace(token_id, std::vector<llama_token>());
         } else {
-            size_t word_len = word.size(), str_len = str.size();
+            size_t word_len = word.size();
+            size_t str_len = str.size();
             size_t pos = -1;
             while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
                 bool match = true;

+ 1 - 1
tests/test-sampling.cpp

@@ -145,7 +145,7 @@ static void test_penalties(
     sampler_tester tester(probs, probs_expected);
 
     const size_t n_vocab = probs.size();
-    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+    auto * sampler = llama_sampler_init_penalties(last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
 
     for (size_t i = 0; i < last_tokens.size(); i++) {
         llama_sampler_accept(sampler, last_tokens[i]);