Kaynağa Gözat

llama : refactor samplers internal implementation (#9370)

slaren 1 yıl önce
ebeveyn
işleme
19f4a7b296
4 değiştirilmiş dosya ile 509 ekleme ve 438 silme
  1. 4 0
      src/llama-impl.h
  2. 499 424
      src/llama-sampling.cpp
  3. 0 10
      src/llama-sampling.h
  4. 6 4
      tests/test-sampling.cpp

+ 4 - 0
src/llama-impl.h

@@ -101,6 +101,10 @@ struct ring_buffer {
     }
 
     void push_back(const T & value) {
+        if (capacity == 0) {
+            throw std::runtime_error("ring buffer: capacity is zero");
+        }
+
         if (sz == capacity) {
             // advance the start when buffer is full
             first = (first + 1) % capacity;

Dosya farkı çok büyük olduğundan ihmal edildi
+ 499 - 424
src/llama-sampling.cpp


+ 0 - 10
src/llama-sampling.h

@@ -23,16 +23,6 @@ struct llama_sampler_chain {
     mutable int32_t n_sample;
 };
 
-using llama_token_cnt = std::unordered_map<llama_token, int>;
-
-// TODO: tmp exposed until test-sampling is fixed
-void llama_sampler_penalties_impl(
-       llama_token_data_array * cur_p,
-        const llama_token_cnt & token_count,
-                        float   penalty_repeat,
-                        float   penalty_freq,
-                        float   penalty_present);
-
 struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,

+ 6 - 4
tests/test-sampling.cpp

@@ -148,15 +148,17 @@ static void test_penalties(
         cur.emplace_back(llama_token_data{token_id, logit, 0.0f});
     }
 
-    llama_token_cnt token_count;
+    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
+
+    auto * sampler = llama_sampler_init_penalties(n_vocab, LLAMA_TOKEN_NULL, LLAMA_TOKEN_NULL, last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence, false, false);
+
     for (size_t i = 0; i < last_tokens.size(); i++) {
-        token_count[last_tokens[i]]++;
+        llama_sampler_accept(sampler, last_tokens[i]);
     }
 
-    llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false };
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
-    llama_sampler_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid
+    APPLY(sampler, &cur_p);
     APPLY(llama_sampler_init_softmax(), &cur_p);
     DUMP(&cur_p);
 

Bu fark içinde çok fazla dosya değişikliği olduğu için bazı dosyalar gösterilmiyor