Przeglądaj źródła

llama : add option for greedy sampling with probs (#3813)

* llama : add option for greedy sampling with probs

* llama : add comment about llama_sample_token_greedy() missing probs

* sampling : temp == 0.0 -> no probs, temp < 0.0 -> probs
Georgi Gerganov 2 lat temu
rodzic
commit
ee1a0ec9cb
4 zmienionych plików z 9 dodań i 3 usunięć
  1. 1 0
      common/common.cpp
  2. 6 2
      common/sampling.cpp
  3. 1 1
      examples/speculative/speculative.cpp
  4. 1 0
      llama.h

+ 1 - 0
common/common.cpp

@@ -224,6 +224,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.temp = std::stof(argv[i]);
+            sparams.temp = std::max(sparams.temp, 0.0f);
         } else if (arg == "--tfs") {
             if (++i >= argc) {
                 invalid_param = true;

+ 6 - 2
common/sampling.cpp

@@ -167,8 +167,12 @@ llama_token llama_sampling_sample(
         llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
-    if (temp <= 0) {
-        // greedy sampling
+    if (temp < 0.0) {
+        // greedy sampling, with probs
+        llama_sample_softmax(ctx_main, &cur_p);
+        id = cur_p.data[0].id;
+    } else if (temp == 0.0) {
+        // greedy sampling, no probs
         id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {

+ 1 - 1
examples/speculative/speculative.cpp

@@ -148,7 +148,7 @@ int main(int argc, char ** argv) {
     std::vector<seq_draft> drafts(n_seq_dft);
 
     params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    params.sparams.temp = std::max(0.01f, params.sparams.temp);
+    params.sparams.temp = -1.0f;    // force greedy sampling with probs for the draft model
 
     for (int s = 0; s < n_seq_dft; ++s) {
         drafts[s].ctx_sampling = llama_sampling_init(params.sparams);

+ 1 - 0
llama.h

@@ -658,6 +658,7 @@ extern "C" {
                            float * mu);
 
     /// @details Selects the token with the highest probability.
+    ///          Does not compute the token probabilities. Use llama_sample_softmax() instead.
     LLAMA_API llama_token llama_sample_token_greedy(
             struct llama_context * ctx,
           llama_token_data_array * candidates);