`grammars`: fix resampling logic regression (#7424)

Olivier Chafik 1 year ago
parent commit e402de364b
2 changed files with 9 additions and 8 deletions
  1. common/sampling.cpp (+7 -6)
  2. examples/main/main.cpp (+2 -2)

common/sampling.cpp (+7 -6)

@@ -179,7 +179,7 @@ static llama_token llama_sampling_sample_impl(
                   struct llama_context * ctx_main,
                   struct llama_context * ctx_cfg,
                   const int idx,
-                  bool is_resampling) {  // Add a parameter to indicate if we are resampling
+                  bool is_resampling) {
     const llama_sampling_params & params = ctx_sampling->params;
 
     const float   temp            = params.temp;
@@ -188,8 +188,8 @@ static llama_token llama_sampling_sample_impl(
     const float   mirostat_eta    = params.mirostat_eta;
 
     std::vector<float> original_logits;
-    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, !is_resampling, &original_logits);
-    if (!is_resampling) {
+    auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits);
+    if (ctx_sampling->grammar != NULL && !is_resampling) {
         GGML_ASSERT(!original_logits.empty());
     }
     llama_token id = 0;
@@ -252,7 +252,7 @@ static llama_token llama_sampling_sample_impl(
             // Restore logits from the copy
             std::copy(original_logits.begin(), original_logits.end(), logits);
 
-            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, true);  // Pass true for is_resampling
+            return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true);
         }
     }
 
@@ -285,7 +285,8 @@ static llama_token_data_array llama_sampling_prepare_impl(
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    if (apply_grammar && original_logits != NULL) {
+    if (ctx_sampling->grammar != NULL && !apply_grammar) {
+        GGML_ASSERT(original_logits != NULL);
         // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this.
         *original_logits = {logits, logits + llama_n_vocab(llama_get_model(ctx_main))};
     }
@@ -342,7 +343,7 @@ llama_token llama_sampling_sample(
                   struct llama_context * ctx_cfg,
                   const int idx) {
     // Call the implementation function with is_resampling set to false by default
-    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
+    return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false);
 }
 
 llama_token_data_array llama_sampling_prepare(
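Taken together, the sampling.cpp changes flip the meaning of `apply_grammar` so that the first pass samples from the unconstrained logits and the grammar is only enforced on the resampling pass, after an unconstrained token has been rejected. Below is a minimal, self-contained sketch of that two-pass flow. Every type and helper in it (`Logits`, `token_allowed_by_grammar`, `apply_grammar_constraints`, `sample_token`, `sample_impl`) is a hypothetical stand-in rather than the llama.cpp API; only the control flow is meant to mirror the patched `llama_sampling_sample_impl`.

```cpp
#include <cstdio>
#include <vector>

using Logits = std::vector<float>;

// Hypothetical grammar: only even token ids are allowed.
static bool token_allowed_by_grammar(int token) { return token % 2 == 0; }

// Hypothetical grammar constraint: mask out tokens the grammar would reject.
static void apply_grammar_constraints(Logits & logits) {
    for (int i = 0; i < (int) logits.size(); ++i) {
        if (!token_allowed_by_grammar(i)) logits[i] = -1e9f;
    }
}

// Hypothetical sampler: greedy argmax over the logits.
static int sample_token(const Logits & logits) {
    int best = 0;
    for (int i = 1; i < (int) logits.size(); ++i) {
        if (logits[i] > logits[best]) best = i;
    }
    return best;
}

static int sample_impl(Logits & logits, bool is_resampling) {
    // First pass: keep an untouched copy so it can be restored before resampling
    // (mirrors `original_logits` in common/sampling.cpp).
    Logits original;
    if (!is_resampling) {
        original = logits;
    } else {
        // Resampling pass: enforce the grammar up front, matching
        // `/* apply_grammar= */ is_resampling` in the patch.
        apply_grammar_constraints(logits);
    }

    const int id = sample_token(logits);

    if (!is_resampling && !token_allowed_by_grammar(id)) {
        // The unconstrained token violates the grammar: restore the original
        // logits and run a second, grammar-constrained pass.
        logits = original;
        return sample_impl(logits, /* is_resampling= */ true);
    }
    return id;
}

int main() {
    Logits logits = {0.1f, 0.9f, 0.3f, 0.2f}; // token 1 wins unconstrained, but is rejected
    printf("sampled token: %d\n", sample_impl(logits, /* is_resampling= */ false)); // prints 2
    return 0;
}
```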

examples/main/main.cpp (+2 -2)

@@ -707,7 +707,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id, true);
+            llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -728,7 +728,7 @@ int main(int argc, char ** argv) {
 
                 // push the prompt in the sampling context in order to apply repetition penalties later
                 // for the prompt, we don't apply grammar rules
-                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
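The main.cpp side only adds named-argument comments, but it makes the contract of `llama_sampling_accept` explicit: tokens sampled by the model advance the grammar state, while prompt tokens are recorded for repetition penalties only. A small sketch of that pattern, where `GrammarState` and `sampling_accept` are hypothetical stand-ins for the real sampling context rather than llama.cpp functions:

```cpp
#include <vector>

// Hypothetical stand-in for the grammar state carried by the sampling context.
struct GrammarState {
    int pos = 0;
    void advance(int /*token*/) { ++pos; } // placeholder for real grammar-stack updates
};

// Record a token for repetition penalties; advance the grammar only for tokens
// that were actually produced under the grammar (not for prompt tokens).
static void sampling_accept(std::vector<int> & prev, GrammarState & grammar,
                            int token, bool apply_grammar) {
    prev.push_back(token);
    if (apply_grammar) {
        grammar.advance(token);
    }
}

int main() {
    std::vector<int> prev;
    GrammarState    grammar;

    sampling_accept(prev, grammar, 42, /* apply_grammar= */ false); // prompt token
    sampling_accept(prev, grammar,  7, /* apply_grammar= */ true);  // sampled token
    return 0;
}
```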