Prechádzať zdrojové kódy

llama : return nullptr from llama_grammar_init (#8093)

* llama : return nullptr from llama_grammar_init

This commit updates llama_grammar_init to return nullptr instead of
throwing an exception.

The motivation for this is that this function is declared inside an
extern "C" block and is intended/may be used from C code which will not
be able to handle exceptions thrown, and results in undefined behavior.

On Windows and using MSVC the following warning is currently generated:
```console
C:\llama.cpp\llama.cpp(13998,1): warning C4297: 'llama_grammar_init':
function assumed not to throw an exception but does
C:\llama.cpp\llama.cpp(13998,1): message :
__declspec(nothrow), throw(), noexcept(true), or noexcept was specified
on the function
```

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

* squash! llama : return nullptr from llama_grammar_init

Add checks for nullptr when calling llama_grammar_init.

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>

---------

Signed-off-by: Daniel Bevenius <daniel.bevenius@gmail.com>
Co-authored-by: Clint Herron <hanclinto@gmail.com>
Daniel Bevenius 1 rok pred
rodič
commit
e6bf007744

+ 10 - 2
common/sampling.cpp

@@ -28,9 +28,13 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
         std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
 
-        result->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        result->grammar = grammar;
     }
 
     result->prev.resize(params.n_prev);
@@ -59,9 +63,13 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     if (!ctx->parsed_grammar.rules.empty()) {
         std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
 
-        ctx->grammar = llama_grammar_init(
+        struct llama_grammar * grammar = llama_grammar_init(
                 grammar_rules.data(),
                 grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+        if (grammar == nullptr) {
+            throw std::runtime_error("Failed to initialize llama_grammar");
+        }
+        ctx->grammar = grammar;
     }
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);

+ 3 - 1
examples/gbnf-validator/gbnf-validator.cpp

@@ -101,7 +101,9 @@ int main(int argc, char** argv) {
     auto grammar = llama_grammar_init(
             grammar_rules.data(),
             grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-
+    if (grammar == nullptr) {
+        throw std::runtime_error("Failed to initialize llama_grammar");
+    }
     // Read the input file
     std::string input_str;
     {

+ 2 - 1
llama.cpp

@@ -14500,7 +14500,8 @@ struct llama_grammar * llama_grammar_init(
             continue;
         }
         if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
-            throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
+            LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i);
+            return nullptr;
         }
     }
 

+ 6 - 0
llama.h

@@ -924,6 +924,12 @@ extern "C" {
     // Grammar
     //
 
+    /// Initialize a llama_grammar.
+    ///
+    /// @param rules The rule elements of the grammar to initialize.
+    /// @param n_rules The number of rules.
+    /// @param start_rule_index The index of the root rule (the starting point of the grammar).
+    /// @return The initialized llama_grammar or nullptr if initialization failed.
     LLAMA_API struct llama_grammar * llama_grammar_init(
             const llama_grammar_element ** rules,
                                  size_t    n_rules,

+ 3 - 3
tests/test-grammar-integration.cpp

@@ -36,10 +36,10 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
 static bool test_build_grammar_fails(const std::string & grammar_str) {
     fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
     bool grammar_fails = false;
-    try {
-        build_grammar(grammar_str);
+    llama_grammar * grammar = build_grammar(grammar_str);
+    if (grammar != nullptr) {
         fprintf(stderr, "  ❌ Expected build failure, but succeeded\n");
-    } catch (const std::exception & err) {
+    } else {
         grammar_fails = true;
         fprintf(stdout, "  ✅︎\n");
     }

+ 4 - 0
tests/test-llama-grammar.cpp

@@ -116,6 +116,10 @@ int main()
     std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
     grammar = llama_grammar_init(
         grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
+    if (grammar == nullptr)
+    {
+        throw std::runtime_error("Failed to initialize llama_grammar");
+    }
 
     std::vector<std::vector<llama_grammar_element>> expected_stacks = {
         {