Bladeren bron

llama : add token matching support to llama-grammar (#17816)

* llama : add token support to llama-grammar

* fix inverse token comment

* refactor trigger_patterns to replay tokens instead of the entire string

* add token documentation

* fix test-llama-grammar

* improve test cases for tokens
Aldehir Rojas 1 maand geleden
bovenliggende
commit
e39502e74b
6 gewijzigde bestanden met toevoegingen van 400 en 38 verwijderingen
  1. 24 0
      grammars/README.md
  2. 233 33
      src/llama-grammar.cpp
  3. 20 1
      src/llama-grammar.h
  4. 108 3
      tests/test-grammar-integration.cpp
  5. 14 0
      tests/test-grammar-parser.cpp
  6. 1 1
      tests/test-llama-grammar.cpp

+ 24 - 0
grammars/README.md

@@ -67,6 +67,30 @@ Parentheses `()` can be used to group sequences, which allows for embedding alte
 - `{m,n}` repeats the precedent symbol or sequence at between `m` and `n` times (included)
 - `{0,n}` repeats the precedent symbol or sequence at most `n` times (included)
 
+## Tokens
+
+Tokens allow grammars to match specific tokenizer tokens rather than character sequences. This is useful for constraining outputs based on special tokens (like `<think>` or `</think>`).
+
+Tokens can be specified in two ways:
+
+1. **Token ID**: Use angle brackets with the token ID in square brackets: `<[token-id]>`. For example, `<[1000]>` matches the token with ID 1000.
+
+2. **Token string**: Use angle brackets with the token text directly: `<token>`. For example, `<think>` will match the token whose text is exactly `<think>`. This only works if the string tokenizes to exactly one token in the vocabulary, otherwise the grammar will fail to parse.
+
+You can negate token matches using the `!` prefix: `!<[1000]>` or `!<think>` matches any token *except* the specified one.
+
+```
+# Match a thinking block: <think>...</think>
+# Using token strings (requires these to be single tokens in the vocab)
+root ::= <think> thinking </think> .*
+thinking ::= !</think>*
+
+# Equivalent grammar using explicit token IDs
+# Assumes token 1000 = <think>, token 1001 = </think>
+root ::= <[1000]> thinking <[1001]> .*
+thinking ::= !<[1001]>*
+```
+
 ## Comments and newlines
 
 Comments can be specified with `#`:

+ 233 - 33
src/llama-grammar.cpp

@@ -181,6 +181,52 @@ static std::pair<uint32_t, const char *> parse_char(const char * src) {
     throw std::runtime_error("unexpected end of input");
 }
 
+static std::pair<uint32_t, const char *> parse_token(const llama_vocab * vocab, const char * src) {
+    const char * pos = src;
+    if (*pos != '<') {
+        throw std::runtime_error(std::string("expecting '<' at ") + pos);
+    }
+    pos++;
+
+    // Parse <[id]>
+    if (*pos == '[') {
+        pos++;
+        const char * int_end = parse_int(pos);
+        uint32_t token_id = std::stoul(std::string(pos, int_end - pos));
+        pos = int_end;
+        if (*pos != ']') {
+            throw std::runtime_error(std::string("expecting ']' at ") + pos);
+        }
+        pos++;
+        if (*pos != '>') {
+            throw std::runtime_error(std::string("expecting '>' at ") + pos);
+        }
+        pos++;
+        return std::make_pair(token_id, pos);
+    }
+
+    if (vocab == nullptr) {
+        throw std::runtime_error(std::string("no vocab to parse token at ") + src);
+    }
+
+    // Parse <token> and tokenize to obtain the token id
+    while (*pos != 0 && *pos != '>') {
+        pos++;
+    }
+    if (*pos != '>') {
+        throw std::runtime_error(std::string("expecting '>' at ") + pos);
+    }
+    pos++;
+
+    llama_token tokens[2];
+    int32_t n_tokens = vocab->tokenize(src, static_cast<int32_t>(pos - src), tokens, 2, false, true);
+    if (n_tokens != 1) {
+        // must tokenize to exactly 1 token
+        throw std::runtime_error("invalid token '" + std::string(src, pos - src) + "'");
+    }
+    return std::make_pair(tokens[0], pos);
+}
+
 static void print_grammar_char(FILE * file, uint32_t c) {
     if (0x20 <= c && c <= 0x7f) {
         fprintf(file, "%c", static_cast<char>(c));
@@ -212,6 +258,8 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
             case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
             case LLAMA_GRETYPE_CHAR_ALT:       fprintf(file, "CHAR_ALT");       break;
             case LLAMA_GRETYPE_CHAR_ANY:       fprintf(file, "CHAR_ANY");       break;
+            case LLAMA_GRETYPE_TOKEN:          fprintf(file, "TOKEN");          break;
+            case LLAMA_GRETYPE_TOKEN_NOT:      fprintf(file, "TOKEN_NOT");      break;
         }
         switch (elem.type) {
             case LLAMA_GRETYPE_END:
@@ -228,6 +276,17 @@ static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) {
                 print_grammar_char(file, elem.value);
                 fprintf(file, "\") ");
                 break;
+            case LLAMA_GRETYPE_TOKEN:
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
+            case LLAMA_GRETYPE_TOKEN_NOT:
+                fprintf(file, "!");
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
         }
     }
     fprintf(file, "\n");
@@ -284,6 +343,17 @@ static void print_rule(
             case LLAMA_GRETYPE_CHAR_ANY:
                 fprintf(file, ".");
                 break;
+            case LLAMA_GRETYPE_TOKEN:
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
+            case LLAMA_GRETYPE_TOKEN_NOT:
+                fprintf(file, "!");
+                fprintf(file, "<[");
+                fprintf(file, "%u", elem.value);
+                fprintf(file, "]> ");
+                break;
         }
         if (is_char_element(elem)) {
             switch (rule[i + 1].type) {
@@ -444,6 +514,17 @@ const char * llama_grammar_parser::parse_sequence(
                 }
             }
             pos = parse_space(pos + 1, is_nested);
+        } else if (*pos == '<' || *pos == '!') { // token
+            auto type = LLAMA_GRETYPE_TOKEN;
+            if (*pos == '!') { // token inverse
+                type = LLAMA_GRETYPE_TOKEN_NOT;
+                pos++;
+            }
+            auto token_pair = parse_token(vocab, pos);
+            const char * token_end  = token_pair.second;
+            last_sym_start = rule.size();
+            rule.push_back({type, token_pair.first});
+            pos = parse_space(token_end, is_nested);
         } else if (is_word_char(*pos)) { // rule reference
             const char * name_end    = parse_name(pos);
             uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos);
@@ -691,6 +772,21 @@ static bool llama_grammar_match_partial_char(
     return !is_positive_char;
 }
 
+// returns true iff token matches the rule at pos (regular or inverse)
+// asserts that pos is pointing to a token element
+static bool llama_grammar_match_token(
+    const llama_grammar_element * pos,
+    const llama_token             token) {
+    GGML_ASSERT(pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT);
+    if (pos->type == LLAMA_GRETYPE_TOKEN) {
+        return pos->value == static_cast<uint32_t>(token);
+    }
+    if (pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        return pos->value != static_cast<uint32_t>(token);
+    }
+    return false;
+}
+
 // transforms a grammar pushdown stack into N possible stacks, all ending
 // at a character range (terminal element)
 static void llama_grammar_advance_stack(
@@ -738,6 +834,8 @@ static void llama_grammar_advance_stack(
         case LLAMA_GRETYPE_CHAR:
         case LLAMA_GRETYPE_CHAR_NOT:
         case LLAMA_GRETYPE_CHAR_ANY:
+        case LLAMA_GRETYPE_TOKEN:
+        case LLAMA_GRETYPE_TOKEN_NOT:
             if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) {
                 // only add the stack if it's not a duplicate of one we already have
                 new_stacks.emplace_back(stack);
@@ -831,26 +929,38 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar)
     return grammar->stacks;
 }
 
+static void llama_grammar_accept_chr(
+        struct llama_grammar       & grammar,
+        const llama_grammar_stack  & stack,
+              uint32_t               chr,
+              llama_grammar_stacks & new_stacks) {
+    if (stack.empty()) {
+        return;
+    }
+
+    const llama_grammar_element * pos = stack.back();
+
+    // ignore if this turns into a token
+    if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        return;
+    }
+
+    auto match = llama_grammar_match_char(pos, chr);
+    if (match.first) {
+        llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+        if (!llama_grammar_is_end_of_sequence(match.second)) {
+            new_stack.push_back(match.second);
+        }
+        llama_grammar_advance_stack(grammar.rules, new_stack, new_stacks);
+    }
+}
+
 void llama_grammar_accept(struct llama_grammar * grammar, uint32_t chr) {
     llama_grammar_stacks stacks_new;
     stacks_new.reserve(grammar->stacks.size());
 
     for (const auto & stack : grammar->stacks) {
-        if (stack.empty()) {
-            continue;
-        }
-
-        auto match = llama_grammar_match_char(stack.back(), chr);
-        if (match.first) {
-            const llama_grammar_element * pos = match.second;
-
-            // update top of stack to next element, if any
-            llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
-            if (!llama_grammar_is_end_of_sequence(pos)) {
-                new_stack.push_back(pos);
-            }
-            llama_grammar_advance_stack(grammar->rules, new_stack, stacks_new);
-        }
+        llama_grammar_accept_chr(*grammar, stack, chr, stacks_new);
     }
 
     grammar->stacks = std::move(stacks_new);
@@ -875,6 +985,22 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
 
     const llama_grammar_element * stack_pos = stack.back();
 
+    // if the top of the stack is a token rule, then we only need to check the token id
+    if (stack_pos->type == LLAMA_GRETYPE_TOKEN || stack_pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+        for (const auto & tok : candidates) {
+            if (*tok.code_points == 0) {
+                // reached the end of a token consumed by char rules, reject iff it ended
+                // in a partial response
+                if (tok.partial_utf8.n_remain != 0) {
+                    rejects.push_back(tok);
+                }
+            } else if (!llama_grammar_match_token(stack_pos, tok.id)) {
+                rejects.push_back(tok);
+            }
+        }
+        return rejects;
+    }
+
     llama_grammar_candidates next_candidates;
     next_candidates.reserve(candidates.size());
 
@@ -887,7 +1013,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
                 rejects.push_back(tok);
             }
         } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) {
-            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 });
+            next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8, tok.id });
         } else {
             rejects.push_back(tok);
         }
@@ -905,7 +1031,7 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack(
 
     auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates);
     for (const auto & tok : next_rejects) {
-        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 });
+        rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8, tok.id });
     }
 
     return rejects;
@@ -972,12 +1098,13 @@ struct llama_grammar * llama_grammar_init_impl(
         vocab,
         std::move(vec_rules),
         std::move(stacks),
-        /* .partial_utf8 = */     {},
-        /* .lazy =*/              false,
-        /* .awaiting_trigger = */ false,
-        /* .trigger_buffer = */   "",
-        /* .trigger_tokens   = */ {},
-        /* .trigger_patterns    = */ {},
+        /* .partial_utf8 = */             {},
+        /* .lazy = */                     false,
+        /* .awaiting_trigger = */         false,
+        /* .trigger_buffer = */           "",
+        /* .trigger_buffer_positions = */ {},
+        /* .trigger_tokens = */           {},
+        /* .trigger_patterns = */         {},
     };
 }
 
@@ -990,7 +1117,7 @@ struct llama_grammar * llama_grammar_init_impl(
                             size_t num_trigger_patterns,
                const llama_token * trigger_tokens,
                             size_t num_trigger_tokens) {
-    llama_grammar_parser parser;
+    llama_grammar_parser parser(vocab);
 
     // if there is a grammar, parse it
     // rules will be empty (default) if there are parse errors
@@ -1077,10 +1204,11 @@ struct llama_grammar * llama_grammar_init_impl(
         vocab,
         std::move(vec_rules),
         std::move(stacks),
-        /* .partial_utf8 = */     {},
-        /* .lazy = */             lazy,
-        /* .awaiting_trigger = */ lazy,
-        /* .trigger_buffer = */   "",
+        /* .partial_utf8 = */             {},
+        /* .lazy = */                     lazy,
+        /* .awaiting_trigger = */         lazy,
+        /* .trigger_buffer = */           "",
+        /* .trigger_buffer_positions = */ {},
         std::move(vec_trigger_tokens),
         std::move(vec_trigger_patterns),
     };
@@ -1103,6 +1231,7 @@ struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & gra
         grammar.lazy,
         grammar.awaiting_trigger,
         grammar.trigger_buffer,
+        grammar.trigger_buffer_positions,
         grammar.trigger_tokens,
         grammar.trigger_patterns,
     };
@@ -1156,7 +1285,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
             cur_p->data[i].logit = -INFINITY;
         } else {
             candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8));
-            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
+            candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second, id });
         }
     }
 
@@ -1175,10 +1304,12 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         if (std::find(grammar.trigger_tokens.begin(), grammar.trigger_tokens.end(), token) != grammar.trigger_tokens.end()) {
             grammar.awaiting_trigger = false;
             grammar.trigger_buffer.clear();
-            llama_grammar_accept_str(grammar, piece);
+            llama_grammar_accept_token(grammar, token, piece);
             LLAMA_LOG_DEBUG("Grammar triggered on token %u (`%s`)", token, piece.c_str());
             return;
         } else {
+            auto position = std::make_pair(grammar.trigger_buffer.size(), grammar.trigger_buffer.size() + piece.size());
+            grammar.trigger_buffer_positions.push_back(std::make_pair(token, position));
             grammar.trigger_buffer += piece;
 
             std::smatch match;
@@ -1196,10 +1327,23 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
                     if (start == std::string::npos) {
                         start = match.position(0);
                     }
+
+                    // replay tokens that overlap with [start, end)
+                    for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) {
+                        auto [tok_start, tok_end] = tok_pos;
+                        if (tok_end <= start) {
+                            continue;
+                        }
+
+                        size_t piece_start = (tok_start < start) ? start : tok_start; // allow for partial token pieces
+                        size_t piece_len = tok_end - piece_start;
+                        auto tok_piece = grammar.trigger_buffer.substr(piece_start, piece_len);
+                        llama_grammar_accept_token(grammar, tok, tok_piece);
+                    }
+
                     auto constrained_str = grammar.trigger_buffer.substr(start);
-                    // std::string constrained_str(match[1].first, grammar.trigger_buffer.end());
                     grammar.trigger_buffer.clear();
-                    llama_grammar_accept_str(grammar, constrained_str);
+                    grammar.trigger_buffer_positions.clear();
                     LLAMA_LOG_DEBUG("Grammar triggered on regex: '%s'\n", constrained_str.c_str());
                     return;
                 }
@@ -1218,7 +1362,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
         GGML_ABORT("fatal error");
     }
 
-    llama_grammar_accept_str(grammar, piece);
+    llama_grammar_accept_token(grammar, token, piece);
 }
 
 void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string & piece) {
@@ -1235,3 +1379,59 @@ void llama_grammar_accept_str(struct llama_grammar & grammar, const std::string
         throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece);
     }
 }
+
+void llama_grammar_accept_token(struct llama_grammar & grammar, llama_token token, const std::string & piece) {
+    // Note terminating 0 in decoded string
+    const auto   decoded     = decode_utf8(piece, grammar.partial_utf8);
+    const auto & code_points = decoded.first;
+
+    llama_grammar_stacks stacks_new;
+    stacks_new.reserve(grammar.stacks.size());
+
+    for (const auto & stack : grammar.stacks) {
+        if (stack.empty()) {
+            continue;
+        }
+
+        const llama_grammar_element * pos = stack.back();
+
+        if (pos->type == LLAMA_GRETYPE_TOKEN || pos->type == LLAMA_GRETYPE_TOKEN_NOT) {
+            if (llama_grammar_match_token(pos, token)) {
+                llama_grammar_stack new_stack(stack.begin(), stack.end() - 1);
+                if (!llama_grammar_is_end_of_sequence(pos + 1)) {
+                    new_stack.push_back(pos + 1);
+                }
+                llama_grammar_advance_stack(grammar.rules, new_stack, stacks_new);
+            }
+        } else {
+            llama_grammar_stacks current_stacks = {stack};
+
+            for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
+                llama_grammar_stacks next_stacks;
+
+                for (const auto & cur_stack : current_stacks) {
+                    llama_grammar_accept_chr(grammar, cur_stack, *it, next_stacks);
+                }
+
+                current_stacks = std::move(next_stacks);
+                if (current_stacks.empty()) {
+                    break;
+                }
+            }
+
+            for (auto & surviving_stack : current_stacks) {
+                if (std::find(stacks_new.begin(), stacks_new.end(), surviving_stack) == stacks_new.end()) {
+                    stacks_new.emplace_back(surviving_stack);
+                }
+            }
+        }
+    }
+
+    grammar.stacks = std::move(stacks_new);
+    grammar.partial_utf8 = decoded.second;
+
+    if (grammar.stacks.empty()) {
+        throw std::runtime_error("Unexpected empty grammar stack after accepting piece: " + piece + " (" + std::to_string(token) + ")");
+    }
+}
+

+ 20 - 1
src/llama-grammar.h

@@ -36,11 +36,17 @@ enum llama_gretype {
 
     // any character (.)
     LLAMA_GRETYPE_CHAR_ANY       = 7,
+
+    // terminal element: token (<[token-id]>)
+    LLAMA_GRETYPE_TOKEN          = 8,
+
+    // inverse token (!<[token-id]>)
+    LLAMA_GRETYPE_TOKEN_NOT      = 9,
 };
 
 typedef struct llama_grammar_element {
     enum llama_gretype type;
-    uint32_t           value; // Unicode code point or rule ID
+    uint32_t           value; // Unicode code point, rule ID, or token ID
 } llama_grammar_element;
 
 struct llama_partial_utf8 {
@@ -52,6 +58,7 @@ struct llama_grammar_candidate {
     size_t               index;
     const uint32_t     * code_points;
     llama_partial_utf8   partial_utf8;
+    llama_token          id;
 };
 
 using llama_grammar_rule  = std::vector<      llama_grammar_element>;
@@ -77,10 +84,13 @@ std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
         const llama_grammar_candidates & candidates);
 
 struct llama_grammar_parser {
+    const llama_vocab * vocab;
     std::map<std::string, uint32_t> symbol_ids;
 
     llama_grammar_rules rules;
 
+    llama_grammar_parser(const struct llama_vocab * vocab = nullptr) : vocab(vocab) {}
+
     llama_grammar_stack c_rules() const;
 
     uint32_t get_symbol_id(const char * src, size_t len);
@@ -112,6 +122,9 @@ struct llama_grammar_trigger_pattern {
 };
 
 struct llama_grammar {
+    // maintain a list of llama_tokens and their positions in the trigger_buffer
+    using token_pos = std::pair<llama_token, std::pair<size_t, size_t>>;
+
     // note: allow null vocab for testing (not great)
     const llama_vocab * vocab;
 
@@ -127,6 +140,7 @@ struct llama_grammar {
     bool                     lazy             = false;
     bool                     awaiting_trigger = false; // Initialized to true for lazy grammars only
     std::string              trigger_buffer;           // Output buffered by lazy grammar. Will be cleared once trigger is found.
+    std::vector<token_pos>   trigger_buffer_positions; // Tokens buffered by lazy grammar. Used to replay when a trigger is found.
     std::vector<llama_token> trigger_tokens;           // Tokens that trigger a lazy grammar, or tokens to force printing of (even if special).
     std::vector<llama_grammar_trigger_pattern>
                              trigger_patterns;         // Regular expressions that trigger a lazy grammar. Must be a full match of the entire generated
@@ -171,3 +185,8 @@ void llama_grammar_accept_impl(
 void llama_grammar_accept_str(
               struct llama_grammar & grammar,
                  const std::string & piece);
+
+void llama_grammar_accept_token(
+              struct llama_grammar & grammar,
+                       llama_token   token,
+                 const std::string & piece);

+ 108 - 3
tests/test-grammar-integration.cpp

@@ -32,13 +32,66 @@ static bool test_build_grammar_fails(const std::string & grammar_str) {
     return grammar_fails;
 }
 
+struct token_and_piece {
+    llama_token token;
+    std::string piece;
+};
+
+// token() encodes a 32-bit ID as 5 bytes: a 0xff marker followed by the ID in big-endian order.
+static std::string token(llama_token id) {
+    return std::string{
+        static_cast<char>(0xff),
+        static_cast<char>((id >> 24) & 0xff),
+        static_cast<char>((id >> 16) & 0xff),
+        static_cast<char>((id >> 8) & 0xff),
+        static_cast<char>(id & 0xff)
+    };
+}
+
+// parse_tokens() parses the token encodes above and UTF-8 text.
+static std::vector<token_and_piece> parse_tokens(const std::string & input) {
+    std::vector<token_and_piece> result;
+    result.reserve(input.size());
+    size_t offset = 0;
+    while (offset < input.size()) {
+        try {
+            if (static_cast<unsigned char>(input[offset]) == 0xff) {
+                if (offset + 5 > input.size()) {
+                    throw std::runtime_error("not enough bytes for token id");
+                }
+                uint32_t val =
+                    (static_cast<unsigned char>(input[offset + 1]) << 24) |
+                    (static_cast<unsigned char>(input[offset + 2]) << 16) |
+                    (static_cast<unsigned char>(input[offset + 3]) << 8)  |
+                    (static_cast<unsigned char>(input[offset + 4]));
+                auto piece = "<[" + std::to_string(val) + "]>";
+                result.push_back({static_cast<llama_token>(val), piece});
+                offset += 5;
+            } else {
+                uint32_t cpt = unicode_cpt_from_utf8(input, offset);
+                result.push_back({0, unicode_cpt_to_utf8(cpt)});
+            }
+        } catch (const std::invalid_argument & /*ex*/) {
+            // Silently ignore invalid UTF-8 input to avoid leaking the exception beyond llama_tokenize
+            ++offset;
+            result.push_back({0, unicode_cpt_to_utf8(0xFFFD)}); // replacement character
+        }
+    }
+    return result;
+}
+
 static bool match_string(const std::string & input, llama_grammar * grammar) {
-    const auto cpts = unicode_cpts_from_utf8(input);
+    const auto parsed = parse_tokens(input);
 
     auto & stacks_cur = llama_grammar_get_stacks(grammar);
 
-    for (const auto & cpt : cpts) {
-        llama_grammar_accept(grammar, cpt);
+    for (const auto & in : parsed) {
+        try {
+            llama_grammar_accept_token(*grammar, in.token, in.piece);
+        } catch (const std::runtime_error & /*e*/) {
+            // normally this shouldn't get hit because of llama_grammar_apply
+            return false;
+        }
 
         if (stacks_cur.empty()) {
             // no stacks means that the grammar failed to match at this point
@@ -426,6 +479,30 @@ static void test_simple_grammar() {
             "12a45",
         }
     );
+
+    // Test case for a simple grammar with tokens
+    test_grammar(
+        "simple grammar with tokens",
+        R"""(
+            root ::= <[10]> content <[11]>
+            content ::= (!<[11]>)*)""",
+        // Passing strings
+        {
+            token(10) + "hello world" + token(11),
+            token(10) + "text with " + token(12) + " other tokens " + token(13) + " mixed in" + token(11),
+            token(10) + token(11),
+            token(10) + token(12) + token(13) + token(14) + token(15) + token(11),
+            token(10) + "a" + token(11),
+        },
+        // Failing strings
+        {
+            token(10) + "missing end token",
+            token(10),
+            "missing start token" + token(11),
+            token(10) + token(11) + token(11),  // double end token
+            token(11) + "wrong order" + token(10),
+        }
+    );
 }
 
 static void test_complex_grammar() {
@@ -487,6 +564,34 @@ static void test_complex_grammar() {
             "123+456*789-123/456+789*123-456/789+123*456-789/123+456*789-123/456+789*123-456/",
         }
     );
+
+    // Test case for a more complex grammar with tokens
+    test_grammar(
+        "complex grammar with tokens",
+        R"""(
+            root ::= reasoning+ content tool-call*
+            reasoning ::= <[10]> (!<[11]>)* <[11]>
+            content ::= <[20]> (!<[21]>)* <[21]>
+            tool-call ::= <[12]> name <[13]> args <[14]>
+            name ::= (!<[13]>)+
+            args ::= (!<[14]>)*)""",
+        // Passing strings
+        {
+            token(10) + "I am thinking" + token(11) + token(20) + "hello world!" + token(21) + token(12) + "search" + token(13) + "query=test" + token(14),
+            token(10) + "reasoning 1" + token(11) + token(10) + "reasoning 2" + token(11) + token(20) + token(21) + token(12) + "tool" + token(13) + token(14),
+            token(10) + token(11) + token(20) + "content" + token(21),
+            token(10) + "think" + token(12) + " nested" + token(11) + token(20) + token(10) + "more content" + token(21) + token(12) + "fn" + token(13) + "x=1,y=2" + token(14) + token(12) + "fn2" + token(13) + token(14),
+            token(10) + "reasoning" + token(11) + token(10) + "more" + token(11) + token(10) + "even more" + token(11) + token(20) + "text" + token(21) + token(12) + "a" + token(13) + "b" + token(14) + token(12) + "c" + token(13) + "d" + token(14),
+        },
+        // Failing strings
+        {
+            token(20) + "content only" + token(21),
+            token(10) + "no closing reasoning",
+            token(10) + token(11) + token(20) + "no closing content",
+            token(10) + token(11) + token(20) + token(21) + token(12) + "incomplete tool",
+            token(10) + token(11) + token(11) + token(20) + token(21),
+        }
+    );
 }
 
 static void test_special_chars() {

+ 14 - 0
tests/test-grammar-parser.cpp

@@ -515,5 +515,19 @@ int main()
         {LLAMA_GRETYPE_END, 0},
     });
 
+    // <[1000]> = "<think>"
+    // <[1001]> = "</think>"
+    verify_parsing(R"""(
+        root  ::= <[1000]> !<[1001]> <[1001]>
+    )""", {
+        {"root", 0}
+    }, {
+        // root (index 0)
+        {LLAMA_GRETYPE_TOKEN, 1000},
+        {LLAMA_GRETYPE_TOKEN_NOT, 1001},
+        {LLAMA_GRETYPE_TOKEN, 1001},
+        {LLAMA_GRETYPE_END, 0},
+    });
+
     return 0;
 }

+ 1 - 1
tests/test-llama-grammar.cpp

@@ -202,7 +202,7 @@ int main()
         uint32_t *cp = new uint32_t[2]; // dynamically allocate memory for code_point
         cp[0] = 37 + i;
         cp[1] = 0;
-        next_candidates[i] = {i, cp, {}};
+        next_candidates[i] = {i, cp, {}, 0};
     }
 
     std::vector<std::vector<std::pair<uint32_t, uint16_t>>> expected_reject = {