View source

sampling : when top-k <= 0 -> noop (#13173)

ggml-ci
Georgi Gerganov committed 8 months ago
Commit d9d398f84f
2 files changed, 3 insertions(+), 1 deletion(-)
  1. include/llama.h (+1 −0)
  2. src/llama-sampling.cpp (+2 −1)

+ 1 - 0
include/llama.h

@@ -1232,6 +1232,7 @@ extern "C" {
         "will be removed in the future (see https://github.com/ggml-org/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+    /// Setting k <= 0 makes this a noop
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
 
     /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751

+ 2 - 1
src/llama-sampling.cpp

@@ -232,7 +232,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     // }
 
     if (k <= 0) {
-        k = cur_p->size;
+        return;
     }
 
     k = std::min(k, (int) cur_p->size);
@@ -298,6 +298,7 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
         }
         cur_p->sorted = true;
     }
+
     cur_p->size = k;
 }