Jelajahi Sumber

common : restore grammar-based rejection sampling (#18137)

* common : restart grammar-based rejection sampling

* sampling : allow null samplers
Georgi Gerganov 1 bulan lalu
induk
melakukan
4301e27319

+ 51 - 37
common/sampling.cpp

@@ -104,10 +104,9 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
+    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
-    bool grammar;
-
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
+    llama_sampler * grmr = nullptr;
     llama_sampler * chain = llama_sampler_chain_init(lparams);
 
-    bool grammar = false;
     std::vector<llama_sampler *> samplers;
 
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
-        grammar = true;
+        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
         if (!params.grammar.empty()) {
              if (params.grammar_lazy) {
-                 samplers.push_back(
-                         llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                             trigger_patterns_c.data(), trigger_patterns_c.size(),
-                             trigger_tokens.data(),     trigger_tokens.size()));
+                 grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                         trigger_patterns_c.data(), trigger_patterns_c.size(),
+                         trigger_tokens.data(), trigger_tokens.size());
              } else {
-                 samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+                 grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
              }
-
-             grammar = true;
         }
     }
 
@@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     auto * result = new common_sampler {
         /* .params  = */ params,
+        /* .grmr    = */ grmr,
         /* .chain   = */ chain,
-        /* .grammar = */ grammar,
         /* .prev    = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur     = */ {},
         /* .cur_p   = */ {},
@@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
@@ -324,25 +320,12 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (gsmpl->grammar) {
-        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
-
-        for (int i = 0; i < n_smpl; i++) {
-            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-
-            // the grammar sampler is always the first one
-            if (i == 0) {
-                if (accept_grammar) {
-                    llama_sampler_accept(smpl, token);
-                }
-            } else {
-                llama_sampler_accept(smpl, token);
-            }
-        }
-    } else {
-        llama_sampler_accept(gsmpl->chain, token);
+    if (gsmpl->grmr && accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
     }
 
+    llama_sampler_accept(gsmpl->chain, token);
+
     gsmpl->prev.push_back(token);
 }
 
@@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
         /* .params  = */ gsmpl->params,
+        /* .grmr    = */ llama_sampler_clone(gsmpl->grmr),
         /* .chain   = */ llama_sampler_clone(gsmpl->chain),
-        /* .grammar = */ gsmpl->grammar,
         /* .prev    = */ gsmpl->prev,
         /* .cur     = */ gsmpl->cur,
         /* .cur_p   = */ gsmpl->cur_p,
@@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
     return gsmpl->chain;
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     llama_token id = LLAMA_TOKEN_NULL;
 
+    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     gsmpl->set_logits(ctx, idx);
 
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
+    }
+
+    llama_sampler_apply(chain, &cur_p);
+
+    id = cur_p.data[cur_p.selected].id;
+
+    if (grammar_first) {
+        return id;
+    }
+
+    // check if it the sampled token fits the grammar (grammar-based rejection sampling)
+    {
+        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+        llama_sampler_apply(grmr, &single_token_data_array);
+
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
+        }
+    }
+
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);
+
+    llama_sampler_apply(grmr,  &cur_p);
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
@@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

+ 6 - 3
common/sampling.h

@@ -57,7 +57,10 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 // - check if the token fits the grammar (if any)
 // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
 //
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);
+// if grammar_first is true, the grammar is applied before the samplers (slower)
+// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
+//
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
 // generalized version of common_sampler_sample
 //
@@ -75,10 +78,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 // returns at least 1 token, up to idxs.size()
 //
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
 
 // assume idxs == [ 0, 1, 2, ..., draft.size() ]
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 

+ 1 - 1
common/speculative.cpp

@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
     for (int i = 0; i < params.n_draft; ++i) {
         common_batch_clear(batch);
 
-        common_sampler_sample(smpl, ctx_dft, 0);
+        common_sampler_sample(smpl, ctx_dft, 0, true);
 
         const auto * cur_p = common_sampler_get_candidates(smpl, true);
 

+ 2 - 2
examples/speculative/speculative.cpp

@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
                 bool accept = false;
                 if (params.sampling.temp > 0) {
                     // stochastic verification
-                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
+                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
                     auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
 
@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
-                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);
+                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
                 const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
 

+ 16 - 0
src/llama-sampling.cpp

@@ -362,23 +362,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
 }
 
 void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->accept) {
         smpl->iface->accept(smpl, token);
     }
 }
 
 void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
+    if (!smpl) {
+        return;
+    }
+
     GGML_ASSERT(smpl->iface->apply);
     smpl->iface->apply(smpl, cur_p);
 }
 
 void llama_sampler_reset(struct llama_sampler * smpl) {
+    if (!smpl) {
+        return;
+    }
+
     if (smpl->iface->reset) {
         smpl->iface->reset(smpl);
     }
 }
 
 struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
+    if (!smpl) {
+        return nullptr;
+    }
+
     if (smpl->iface->clone) {
         return smpl->iface->clone(smpl);
     }