Explorar el Código

common, server : surface min_keep as its own parameter (#5567)

* Feature - surface min_keep as its own parameter

* Updated README with min_keep param
Robey Holderith hace 1 año
padre
commit
5ee99c32f5

+ 1 - 0
common/common.cpp

@@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     }
     fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
     fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+    fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
     fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
     fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);

+ 4 - 1
common/sampling.cpp

@@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
             llama_sample_temp(ctx_main, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
-            sampler_queue(ctx_main, params, cur_p, 1);
+            // temperature sampling
+            size_t min_keep = std::max(1, params.min_keep);
+
+            sampler_queue(ctx_main, params, cur_p, min_keep);
 
             id = llama_sample_token(ctx_main, &cur_p);
 

+ 1 - 0
common/sampling.h

@@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
 typedef struct llama_sampling_params {
     int32_t     n_prev                = 64;       // number of previous tokens to remember
     int32_t     n_probs               = 0;        // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t     min_keep              = 0;        // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t     top_k                 = 40;       // <= 0 to use vocab size
     float       top_p                 = 0.95f;    // 1.0 = disabled
     float       min_p                 = 0.05f;    // 0.0 = disabled

+ 2 - 0
examples/server/README.md

@@ -199,6 +199,8 @@ node index.js
 
     `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0)
 
+    `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum (default: 0)
+
     `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `prompt`. You can determine the place of the image in the prompt as in the following: `USER:[img-12]Describe the image in detail.\nASSISTANT:`. In this case, `[img-12]` will be replaced by the embeddings of the image with id `12` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 12}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
 
     `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1)

+ 4 - 0
examples/server/public/index.html

@@ -234,6 +234,7 @@
       mirostat_eta: 0.1, // learning rate
       grammar: '',
       n_probs: 0, // no completion_probabilities,
+      min_keep: 0, // min probs from each sampler,
       image_data: [],
       cache_prompt: true,
       api_key: ''
@@ -791,6 +792,9 @@
             <fieldset>
               ${IntField({ label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs })}
             </fieldset>
+            <fieldset>
+              ${IntField({ label: "Min Probabilities from each Sampler", max: 10, min: 0, name: "min_keep", value: params.value.min_keep })}
+            </fieldset>
             <fieldset>
               <label for="api_key">API Key</label>
               <input type="text" name="api_key" value="${params.value.api_key}" placeholder="Enter API key" oninput=${updateParams} />

+ 2 - 0
examples/server/server.cpp

@@ -548,6 +548,7 @@ struct llama_server_context
         slot->params.seed               = json_value(data, "seed",              default_params.seed);
         slot->sparams.grammar           = json_value(data, "grammar",           default_sparams.grammar);
         slot->sparams.n_probs           = json_value(data, "n_probs",           default_sparams.n_probs);
+        slot->sparams.min_keep          = json_value(data, "min_keep",          default_sparams.min_keep);
 
         if (slot->n_predict > 0 && slot->params.n_predict > slot->n_predict) {
             // Might be better to reject the request with a 400 ?
@@ -1093,6 +1094,7 @@ struct llama_server_context
             {"stream",            slot.params.stream},
             {"logit_bias",        slot.sparams.logit_bias},
             {"n_probs",           slot.sparams.n_probs},
+            {"min_keep",          slot.sparams.min_keep},
             {"grammar",           slot.sparams.grammar},
             {"samplers",          samplers_sequence}
         };