Преглед изворни кода

sampling: fix top_k <= 0 (#5388)

* sampling: fix top_k <= 0

* Update llama.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Johannes Gäßler пре 1 година
родитељ
комит
26d4efd11e
3 измењених фајлова са 7 додато и 1 уклоњено
  1. 1 1
      common/sampling.cpp
  2. 4 0
      llama.cpp
  3. 2 0
      tests/test-sampling.cpp

+ 1 - 1
common/sampling.cpp

@@ -132,7 +132,7 @@ static void sampler_queue(
     const float         temp              = params.temp;
     const float         dynatemp_range    = params.dynatemp_range;
     const float         dynatemp_exponent = params.dynatemp_exponent;
-    const int32_t       top_k             = params.top_k <= 0 ? n_vocab : params.top_k;
+    const int32_t       top_k             = params.top_k;
     const float         top_p             = params.top_p;
     const float         min_p             = params.min_p;
     const float         tfs_z             = params.tfs_z;

+ 4 - 0
llama.cpp

@@ -8585,6 +8585,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
     // }
 
     const int64_t t_start_sample_us = ggml_time_us();
+    
+    if (k <= 0) {
+        k = candidates->size;
+    }
 
     k = std::max(k, (int) min_keep);
     k = std::min(k, (int) candidates->size);

+ 2 - 0
tests/test-sampling.cpp

@@ -235,6 +235,8 @@ int main(void) {
 
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
     test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
+    test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
 
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
     test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);