|
@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
|
|
|
|
|
|
|
|
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
|
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
|
|
|
|
|
|
|
- result->grammar = llama_grammar_init(
|
|
|
|
|
|
|
+ struct llama_grammar * grammar = llama_grammar_init(
|
|
|
grammar_rules.data(),
|
|
grammar_rules.data(),
|
|
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
|
|
|
|
+ if (grammar == nullptr) {
|
|
|
|
|
+ throw std::runtime_error("Failed to initialize llama_grammar");
|
|
|
|
|
+ }
|
|
|
|
|
+ result->grammar = grammar;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
result->prev.resize(params.n_prev);
|
|
result->prev.resize(params.n_prev);
|
|
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
|
|
|
if (!ctx->parsed_grammar.rules.empty()) {
|
|
if (!ctx->parsed_grammar.rules.empty()) {
|
|
|
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
|
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
|
|
|
|
|
|
|
- ctx->grammar = llama_grammar_init(
|
|
|
|
|
|
|
+ struct llama_grammar * grammar = llama_grammar_init(
|
|
|
grammar_rules.data(),
|
|
grammar_rules.data(),
|
|
|
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
|
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
|
|
|
|
+ if (grammar == nullptr) {
|
|
|
|
|
+ throw std::runtime_error("Failed to initialize llama_grammar");
|
|
|
|
|
+ }
|
|
|
|
|
+ ctx->grammar = grammar;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|