Kaynağa Gözat

grammar : fix integer overflow (#17381)

* Fix DoS / integer overflow

* Remove optional, use INT64_MAX instead as placeholder value (it's technically -1, so it fits :)

* White space

* Actually, since it's unsigned, use UINT64_MAX
Piotr Wilkin (ilintar) 1 ay önce
ebeveyn
işleme
92c0b387a9
1 değiştirilmiş dosya ile 15 ekleme ve 8 silme
  1. 15 8
      src/llama-grammar.cpp

+ 15 - 8
src/llama-grammar.cpp

@@ -6,8 +6,10 @@
 
 #include <cmath>
 #include <algorithm>
+#include <cstdint>
 #include <stdexcept>
 
+#define MAX_REPETITION_THRESHOLD 2000
 //
 // helpers
 //
@@ -345,7 +347,9 @@ const char * llama_grammar_parser::parse_sequence(
     size_t last_sym_start = rule.size();
     const char * pos = src;
 
-    auto handle_repetitions = [&](int min_times, int max_times) {
+    // use UINT64_MAX as the empty value because we aligned to the proper unsigned long type so -1 can't be used
+    // (though it's technically the same as -1 now)
+    auto handle_repetitions = [&](unsigned long min_times, unsigned long max_times) {
 
         if (last_sym_start == rule.size()) {
             throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos);
@@ -373,20 +377,20 @@ const char * llama_grammar_parser::parse_sequence(
             rule.resize(last_sym_start);
         } else {
             // Repeat the previous elements (min_times - 1) times
-            for (int i = 1; i < min_times; i++) {
+            for (unsigned long i = 1; i < min_times; i++) {
                 rule.insert(rule.end(), prev_rule.begin(), prev_rule.end());
             }
         }
 
         uint32_t last_rec_rule_id = 0;
-        auto n_opt = max_times < 0 ? 1 : max_times - min_times;
+        auto n_opt = max_times == UINT64_MAX ? 1 : max_times - min_times;
 
         llama_grammar_rule rec_rule(prev_rule);
-        for (int i = 0; i < n_opt; i++) {
+        for (unsigned long i = 0; i < n_opt; i++) {
             rec_rule.resize(prev_rule.size());
             uint32_t rec_rule_id = generate_symbol_id( rule_name);
-            if (i > 0 || max_times < 0) {
-                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id});
+            if (i > 0 || max_times == UINT64_MAX) {
+                rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times == UINT64_MAX ? rec_rule_id : last_rec_rule_id});
             }
             rec_rule.push_back({LLAMA_GRETYPE_ALT, 0});
             rec_rule.push_back({LLAMA_GRETYPE_END, 0});
@@ -478,10 +482,10 @@ const char * llama_grammar_parser::parse_sequence(
                 throw std::runtime_error(std::string("expecting an int at ") + pos);
             }
             const char * int_end = parse_int(pos);
-            int min_times = std::stoul(std::string(pos, int_end - pos));
+            unsigned long min_times = std::stoul(std::string(pos, int_end - pos));
             pos = parse_space(int_end, is_nested);
 
-            int max_times = -1;
+            unsigned long max_times = UINT64_MAX;
 
             if (*pos == '}') {
                 max_times = min_times;
@@ -502,6 +506,9 @@ const char * llama_grammar_parser::parse_sequence(
             } else {
                 throw std::runtime_error(std::string("expecting ',' at ") + pos);
             }
+            if (min_times > MAX_REPETITION_THRESHOLD || (max_times != UINT64_MAX && max_times > MAX_REPETITION_THRESHOLD)) {
+                throw std::runtime_error(std::string("number of repetitions exceeds sane defaults, please reduce the number of repetitions"));
+            }
             handle_repetitions(min_times, max_times);
         } else {
             break;