@@ -4,6 +4,7 @@
 #include "llama-vocab.h"
 #include "llama-grammar.h"
 
+#include <array>
 #include <algorithm>
 #include <cassert>
 #include <cfloat>
@@ -1625,10 +1626,12 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
     auto * ctx = new llama_sampler_grammar;
 
     if (grammar_str != nullptr && grammar_str[0] != '\0') {
+        std::string trigger_pattern;
+        llama_grammar * grammar = nullptr;
         // TODO: remove trigger_words support.
         if (trigger_words != nullptr && num_trigger_words > 0) {
             GGML_ASSERT(trigger_patterns == nullptr && num_trigger_patterns == 0);
-            std::string trigger_pattern("[\\s\\S]*?(");
+            trigger_pattern = "[\\s\\S]*?(";
             for (size_t i = 0; i < num_trigger_words; ++i) {
                 static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]");
                 if (i > 0) {
@@ -1637,15 +1640,17 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
                 trigger_pattern += std::regex_replace(trigger_words[i], special_chars, "\\$0");
             }
             trigger_pattern += ")[\\s\\S]*";
-            const auto * trigger_pattern_c = trigger_pattern.c_str();
-            trigger_patterns = &trigger_pattern_c;
-            num_trigger_patterns = 1;
+
+            std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };
+            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, tmp_trigger_patterns.data(), tmp_trigger_patterns.size(), trigger_tokens, num_trigger_tokens);
+        } else {
+            grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens);
         }
         *ctx = {
             /* .vocab        = */ vocab,
             /* .grammar_str  = */ grammar_str,
             /* .grammar_root = */ grammar_root,
-            /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
+            /* .grammar      = */ grammar,
         };
         if (!ctx->grammar) {
             delete ctx;
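
For readers following the change: the removed lines pointed `trigger_patterns` at locals of the inner `if (trigger_words ...)` block, while the removed `*ctx = { ... }` initializer line only called `llama_grammar_init_impl` after that block had ended, so the grammar was built from a pointer into an already-destroyed `std::string`. The patch instead calls `llama_grammar_init_impl` inside each branch, while the backing string and the temporary `std::array` of `const char *` are still alive. A minimal standalone sketch of that lifetime discipline, not part of the patch (`consume_patterns` is a hypothetical stand-in for `llama_grammar_init_impl`):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <string>

// Hypothetical stand-in for llama_grammar_init_impl: it only reads the
// pattern strings for the duration of the call.
static void consume_patterns(const char * const * patterns, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        std::printf("pattern[%zu] = %s\n", i, patterns[i]);
    }
}

int main() {
    // The owner of the pattern text lives in the enclosing scope, like the
    // hoisted `std::string trigger_pattern;` in the patch.
    std::string trigger_pattern = "[\\s\\S]*?(foo|bar)[\\s\\S]*";

    // A temporary const char* view is built right at the call site, mirroring
    // `std::array<const char *, 1> tmp_trigger_patterns = { trigger_pattern.c_str() };`,
    // so the pointer cannot outlive the string it refers to.
    std::array<const char *, 1> tmp_patterns = { trigger_pattern.c_str() };
    consume_patterns(tmp_patterns.data(), tmp_patterns.size());
    return 0;
}
```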