@@ -104,10 +104,9 @@ struct ring_buffer {
 struct common_sampler {
     common_params_sampling params;
 
+    struct llama_sampler * grmr;
     struct llama_sampler * chain;
 
-    bool grammar;
-
     ring_buffer<llama_token> prev;
 
     std::vector<llama_token_data> cur;
@@ -167,15 +166,14 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     lparams.no_perf = params.no_perf;
 
+    llama_sampler * grmr = nullptr;
     llama_sampler * chain = llama_sampler_chain_init(lparams);
 
-    bool grammar = false;
     std::vector<llama_sampler *> samplers;
 
     if (params.grammar.compare(0, 11, "%llguidance") == 0) {
 #ifdef LLAMA_USE_LLGUIDANCE
-        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
-        grammar = true;
+        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
 #else
         GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
 #endif // LLAMA_USE_LLGUIDANCE
@@ -224,15 +222,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
         if (!params.grammar.empty()) {
             if (params.grammar_lazy) {
-                samplers.push_back(
-                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
-                        trigger_patterns_c.data(), trigger_patterns_c.size(),
-                        trigger_tokens.data(), trigger_tokens.size()));
+                grmr = llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
+                    trigger_patterns_c.data(), trigger_patterns_c.size(),
+                    trigger_tokens.data(), trigger_tokens.size());
             } else {
-                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
+                grmr = llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
             }
-
-            grammar = true;
         }
     }
 
@@ -303,8 +298,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
     auto * result = new common_sampler {
         /* .params = */ params,
+        /* .grmr = */ grmr,
         /* .chain = */ chain,
-        /* .grammar = */ grammar,
         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur = */ {},
         /* .cur_p = */ {},
@@ -315,6 +310,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
 
 void common_sampler_free(struct common_sampler * gsmpl) {
     if (gsmpl) {
+        llama_sampler_free(gsmpl->grmr);
         llama_sampler_free(gsmpl->chain);
 
         delete gsmpl;
@@ -324,25 +320,12 @@ void common_sampler_free(struct common_sampler * gsmpl) {
 void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
     const auto tm = gsmpl->tm();
 
-    if (gsmpl->grammar) {
-        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);
-
-        for (int i = 0; i < n_smpl; i++) {
-            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
-
-            // the grammar sampler is always the first one
-            if (i == 0) {
-                if (accept_grammar) {
-                    llama_sampler_accept(smpl, token);
-                }
-            } else {
-                llama_sampler_accept(smpl, token);
-            }
-        }
-    } else {
-        llama_sampler_accept(gsmpl->chain, token);
+    if (gsmpl->grmr && accept_grammar) {
+        llama_sampler_accept(gsmpl->grmr, token);
     }
 
+    llama_sampler_accept(gsmpl->chain, token);
+
     gsmpl->prev.push_back(token);
 }
 
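Note on the accept path: with the grammar held in its own `grmr` sampler, `common_sampler_accept` no longer walks the chain and special-cases index 0 — the grammar advances only when `accept_grammar` is set, while the chain always advances. A minimal caller sketch of the two modes (the function and token names below are illustrative, not part of this patch):

```cpp
#include <vector>
#include "sampling.h"

// Sketch: prompt tokens vs. generated tokens. `smpl` is assumed to come
// from common_sampler_init().
static void feed_tokens(common_sampler * smpl,
                        const std::vector<llama_token> & prompt_tokens,
                        llama_token generated_token) {
    // prompt tokens update the chain state (penalties, prev ring buffer)
    // but must not advance the grammar parser
    for (const llama_token t : prompt_tokens) {
        common_sampler_accept(smpl, t, /* accept_grammar= */ false);
    }

    // a token produced by common_sampler_sample() advances both
    common_sampler_accept(smpl, generated_token, /* accept_grammar= */ true);
}
```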
@@ -353,8 +336,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
 struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
     return new common_sampler {
         /* .params = */ gsmpl->params,
+        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
         /* .chain = */ llama_sampler_clone(gsmpl->chain),
-        /* .grammar = */ gsmpl->grammar,
         /* .prev = */ gsmpl->prev,
         /* .cur = */ gsmpl->cur,
         /* .cur_p = */ gsmpl->cur_p,
@@ -410,7 +393,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
     return gsmpl->chain;
 }
 
-llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
+llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
     llama_synchronize(ctx);
 
     // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
@@ -418,11 +401,42 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 
     llama_token id = LLAMA_TOKEN_NULL;
 
+    auto & grmr  = gsmpl->grmr;
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
     gsmpl->set_logits(ctx, idx);
 
+    if (grammar_first) {
+        llama_sampler_apply(grmr, &cur_p);
+    }
+
+    llama_sampler_apply(chain, &cur_p);
+
+    id = cur_p.data[cur_p.selected].id;
+
+    if (grammar_first) {
+        return id;
+    }
+
+    // check if the sampled token fits the grammar (grammar-based rejection sampling)
+    {
+        llama_token_data       single_token_data       = { id, 1.0f, 0.0f };
+        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };
+
+        llama_sampler_apply(grmr, &single_token_data_array);
+
+        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
+        if (is_valid) {
+            return id;
+        }
+    }
+
+    // resampling:
+    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
+    gsmpl->set_logits(ctx, idx);
+
+    llama_sampler_apply(grmr,  &cur_p);
     llama_sampler_apply(chain, &cur_p);
 
     GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
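This hunk is the core of the change: grammar-constrained generation becomes rejection sampling. By default the chain samples from the unconstrained distribution and the grammar is only used to validate the chosen token, with a full grammar-masked resample as the fallback; `grammar_first = true` instead masks the candidates before the chain runs. A hedged usage sketch of the new signature (model/context setup and decoding are assumed, not shown in this patch):

```cpp
#include "sampling.h"

// Sketch: a minimal generation loop against the new signature.
static void generate_n(llama_model * model, llama_context * ctx,
                       const common_params_sampling & params, int n) {
    common_sampler * smpl = common_sampler_init(model, params);

    for (int i = 0; i < n; ++i) {
        // grammar_first = false (the cheap common case): only resample under
        // the grammar when the unconstrained sample violates it
        const llama_token id = common_sampler_sample(smpl, ctx, /* idx= */ -1,
                                                     /* grammar_first= */ false);

        common_sampler_accept(smpl, id, /* accept_grammar= */ true);

        // ... llama_decode() the token and break on an end-of-generation token ...
    }

    common_sampler_free(smpl);
}
```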
@@ -432,7 +446,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     return id;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
     GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
 
     std::vector<llama_token> result;
@@ -440,7 +454,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
 
     size_t i = 0;
     for (; i < draft.size(); i++) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -452,7 +466,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     }
 
     if (i == draft.size()) {
-        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);
+        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
 
         common_sampler_accept(gsmpl, id, true);
 
@@ -462,13 +476,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
     return result;
 }
 
-std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
+std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
     std::vector<int> idxs(draft.size() + 1);
     for (size_t i = 0; i < idxs.size(); ++i) {
         idxs[i] = i;
     }
 
-    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
+    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
 }
 
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
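Both `common_sampler_sample_and_accept_n` overloads simply thread `grammar_first` through to the per-token sampling, so speculative decoding keeps the same grammar semantics as single-token generation. A rough sketch of the intended call site (assuming the usual speculative setup, where the draft tokens' logits have already been evaluated in `ctx`):

```cpp
#include <vector>
#include "sampling.h"

// Sketch: verifying a draft with the convenience overload. The result is the
// accepted prefix of the draft plus one token sampled by the target model.
static std::vector<llama_token> verify_draft(common_sampler * smpl,
                                             llama_context * ctx,
                                             const llama_tokens & draft) {
    // samples (and accepts) position by position, so the grammar and chain
    // state stay consistent with whatever prefix of the draft is kept
    return common_sampler_sample_and_accept_n(smpl, ctx, draft,
                                              /* grammar_first= */ false);
}
```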