Przeglądaj źródła

llama: fix error on bad grammar (#12628)

Johannes Gäßler 9 miesięcy temu
rodzic
commit
dd373dd3bf
3 zmienionych plików z 12 dodań i 0 usunięć
  1. 3 0
      common/sampling.cpp
  2. 4 0
      include/llama.h
  3. 5 0
      src/llama-sampling.cpp

+ 3 - 0
common/sampling.cpp

@@ -208,6 +208,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                                                         trigger_patterns_c.data(), trigger_patterns_c.size(),
                                                         trigger_tokens.data(), trigger_tokens.size())
              :      llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
+        if (!grmr) {
+            return nullptr;
+        }
     }
 
     auto * result = new common_sampler {

+ 4 - 0
include/llama.h

@@ -1265,6 +1265,10 @@ extern "C" {
                                float   tau,
                                float   eta);
 
+    /// @details Intializes a GBNF grammar, see grammars/README.md for details.
+    /// @param vocab The vocabulary that this grammar will be used with.
+    /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails.
+    /// @param grammar_root The name of the start symbol for the grammar.
     LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
             const struct llama_vocab * vocab,
                           const char * grammar_str,

+ 5 - 0
src/llama-sampling.cpp

@@ -1477,6 +1477,7 @@ static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sam
     const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
 
     auto * result = llama_sampler_init_grammar_impl(ctx->vocab, nullptr, nullptr, false, nullptr, 0, nullptr, 0, nullptr, 0);
+    GGML_ASSERT(result);
 
     // copy the state
     {
@@ -1548,6 +1549,10 @@ static struct llama_sampler * llama_sampler_init_grammar_impl(
             /* .grammar_root = */ grammar_root,
             /* .grammar      = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root, lazy, trigger_patterns, num_trigger_patterns, trigger_tokens, num_trigger_tokens),
         };
+        if (!ctx->grammar) {
+            delete ctx;
+            return nullptr;
+        }
     } else {
         *ctx = {
             /* .vocab        = */ vocab,