
tokenizer : BPE fixes (#7530)

* Random test: add_bos_token, add_eos_token
* Random test: add BPE models for testing
* Custom regex split fails with codepoint 0
* Fix falcon punctuation regex
* Refactor llm_tokenizer_bpe: move code to constructor
* Move 'add_special_bos/eos' logic to llm_tokenizer_bpe
* Move tokenizer flags to vocab structure.
* Default values for special_add_bos/eos
* Build vocab.special_tokens_cache using vocab token types
* Generalize 'jina-v2' per token attributes
* Fix unicode whitespaces (deepseek-coder, deepseek-llm)
* Skip missing byte tokens (falcon)
* Better unicode data generation
* Replace char32_t with uint32_t
jaime-m-p, 1 year ago
commit 37bef89433
5 changed files with 707 additions and 521 deletions:

  1. llama.cpp (+172 −137)
  2. scripts/gen-unicode-data.py (+120 −60)
  3. tests/test-tokenizer-random.py (+121 −49)
  4. unicode-data.cpp (+273 −267)
  5. unicode.cpp (+21 −8)
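
For context on the llama.cpp diff below: each BPE pre-tokenizer type now carries its split patterns in a `regex_exprs` table. The effect of such a pattern can be previewed from Python with the third-party `regex` module (a sketch for illustration only; the pattern is the GPT-2 default quoted in the diff, while llama.cpp applies it through its own split implementations):

    import regex  # third-party module; unlike 're', it supports \p{L} and \p{N}

    # GPT-2 style BPE pre-tokenization pattern (LLAMA_VOCAB_PRE_TYPE_GPT2 below)
    pattern = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)"

    print(regex.findall(pattern, "Hello world, it's 2024!"))
    # ['Hello', ' world', ',', ' it', "'s", ' 2024', '!']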

llama.cpp (+172 −137)

@@ -2310,16 +2310,17 @@ struct llama_vocab {
     id special_cls_id  = -1;
     id special_mask_id = -1;
 
-    int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
-    int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
-
     id linefeed_id       = 13;
     id special_prefix_id = -1;
     id special_suffix_id = -1;
     id special_middle_id = -1;
     id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
 
-    bool add_space_prefix = true;
+    // tokenizer flags
+    bool tokenizer_add_space_prefix = true;
+    bool tokenizer_add_bos          = false;
+    bool tokenizer_add_eos          = false;
+    bool tokenizer_ignore_merges    = false;
 
     int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
         GGML_ASSERT(token_left.find(' ') == std::string::npos);
@@ -4770,7 +4771,7 @@ static void llm_load_vocab(
 
             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
-                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             } // The default value of add_space_prefix is true.
         } else if (tokenizer_model == "bert") {
             vocab.type = LLAMA_VOCAB_TYPE_WPM;
@@ -4783,13 +4784,13 @@ static void llm_load_vocab(
             vocab.special_pad_id  = 0;
             vocab.special_cls_id  = 101;
             vocab.special_mask_id = 103;
-            vocab.add_space_prefix = false;
+            vocab.tokenizer_add_space_prefix = false;
         } else if (tokenizer_model == "gpt2") {
             vocab.type = LLAMA_VOCAB_TYPE_BPE;
 
             const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
             if (add_space_prefix_keyidx != -1) {
-                vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
+                vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
             }
 
             // read bpe merges and populate bpe ranks
@@ -4847,6 +4848,8 @@ static void llm_load_vocab(
                     tokenizer_pre == "llama-v3" ||
                     tokenizer_pre == "llama-bpe") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
+                vocab.tokenizer_ignore_merges = true;
+                vocab.tokenizer_add_bos = true;
             } else if (
                     tokenizer_pre == "deepseek-llm") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
@@ -4897,6 +4900,14 @@ static void llm_load_vocab(
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_bos = true;
+            vocab.tokenizer_add_eos = false;
+        } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
+            vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
+            vocab.tokenizer_add_bos = true;
+            vocab.tokenizer_add_eos = false;
         } else {
             vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
         }
@@ -5041,10 +5052,10 @@ static void llm_load_vocab(
             bool temp = true;
 
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
-                vocab.special_add_bos = int(temp);
+                vocab.tokenizer_add_bos = temp;
             }
             if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
-                vocab.special_add_eos = int(temp);
+                vocab.tokenizer_add_eos = temp;
             }
         }
 
@@ -5144,7 +5155,7 @@ static void llm_load_vocab(
         );
 
         // set attributes by model/tokenizer name
-        if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+        if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
             _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
         } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
             for (auto id : vocab.cache_special_tokens) {
@@ -13158,112 +13169,142 @@ struct llm_bigram_bpe {
 };
 
 struct llm_tokenizer_bpe {
-    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
-
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
-        int final_prev_index = -1;
-        bool ignore_merges = false;
-
-        std::vector<std::string> word_collection;
-        switch (vocab.type) {
-            case LLAMA_VOCAB_TYPE_BPE:
-                switch (vocab.type_pre) {
-                    case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
-                        ignore_merges = true;
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-
-                            // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DBRX:
-                    case LLAMA_VOCAB_PRE_TYPE_SMAUG:
-                        word_collection = unicode_regex_split(text, {
-                            // same as llama3
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
-                            "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
-                            "\\s+$",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
-                        word_collection = unicode_regex_split(text, {
-                            "[\r\n]",
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "[一-龥ࠀ-一가-퟿]+",
-                            "\\p{N}",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_FALCON:
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_MPT:
-                        // TODO: MPT pre-tokenization regexes are unknown
-                        //       the following are close, but not exact. run the following:
-                        //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
-                        GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
-                        word_collection = unicode_regex_split(text, {
-                            "\\s?\\p{L}+",
-                            "\\s?\\p{P}+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STARCODER:
-                    case LLAMA_VOCAB_PRE_TYPE_REFACT:
-                    case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
-                        word_collection = unicode_regex_split(text, {
-                            "\\p{N}",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_GPT2:
-                    case LLAMA_VOCAB_PRE_TYPE_OLMO:
-                        word_collection = unicode_regex_split(text, {
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
-                    case LLAMA_VOCAB_PRE_TYPE_QWEN2:
-                        word_collection = unicode_regex_split(text, {
-                            // original regex from tokenizer.json
-                            // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-                            "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
-                        });
-                        break;
-                    case LLAMA_VOCAB_PRE_TYPE_PORO:
-                        word_collection = unicode_regex_split(text, {
-                            " ?[^(\\s|.,!?…。,、।۔،)]+",
-                        });
-                        break;
-                    default:
-                        // default regex for BPE tokenization pre-processing
-                        word_collection = unicode_regex_split(text, {
-                            "[\\p{P}\\$\\+<=>\\^~\\|]+",
-                            "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
-                            "\\p{N}+",
-                            "[0-9][0-9][0-9]",
-                        });
-                        break;
-                }
+    llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
+        GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
+        switch (vocab.type_pre) {
+            case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+
+                    // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DBRX:
+            case LLAMA_VOCAB_PRE_TYPE_SMAUG:
+                regex_exprs = {
+                    // same as llama3
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
+                    "\\s?[!-/:-~!-/:-~‘-‟ -。]+",
+                    "\\s+$",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
+                regex_exprs = {
+                    "[\r\n]",
+                    "\\s?\\p{L}+",
+                    "\\s?\\p{P}+",
+                    "[一-龥ࠀ-一가-퟿]+",
+                    "\\p{N}",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_FALCON:
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|`]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "[0-9][0-9][0-9]",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_MPT:
+                // TODO: MPT pre-tokenization regexes are unknown
+                //       the following are close, but not exact. run the following:
+                //       ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
+                GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
+                regex_exprs = {
+                    "\\s?\\p{L}+",
+                    "\\s?\\p{P}+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STARCODER:
+            case LLAMA_VOCAB_PRE_TYPE_REFACT:
+            case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
+                regex_exprs = {
+                    "\\p{N}",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_GPT2:
+            case LLAMA_VOCAB_PRE_TYPE_OLMO:
+                regex_exprs = {
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
+            case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
+            case LLAMA_VOCAB_PRE_TYPE_PORO:
+                regex_exprs = {
+                    " ?[^(\\s|.,!?…。,、।۔،)]+",
+                };
                 break;
             default:
-                GGML_ASSERT(false);
+                // default regex for BPE tokenization pre-processing
+                regex_exprs = {
+                    "[\\p{P}\\$\\+<=>\\^~\\|]+",
+                    "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
+                    "\\p{N}+",
+                    "[0-9][0-9][0-9]",
+                };
                 break;
         }
+    }
+
+    void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
+        output.push_back(token_id);
+    }
+
+    bool append_bos(std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_bos) {
+            GGML_ASSERT(vocab.special_bos_id != -1);
+            output.push_back(vocab.special_bos_id);
+            return true;
+        }
+        return false;
+    }
+
+    bool append_eos(std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_eos) {
+            GGML_ASSERT(vocab.special_eos_id != -1);
+            output.push_back(vocab.special_eos_id);
+            return true;
+        }
+        return false;
+    }
+
+    void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
+        if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a BOS token to the prompt as specified by the model but the prompt "
+                "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+        if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
+            LLAMA_LOG_WARN(
+                "%s: Added a EOS token to the prompt as specified by the model but the prompt "
+                "also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
+                "Are you sure this is what you want?\n", __FUNCTION__);
+        }
+    }
+
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+        int final_prev_index = -1;
+
+        const auto word_collection = unicode_regex_split(text, regex_exprs);
 
         symbols_final.clear();
 
@@ -13274,7 +13315,7 @@ struct llm_tokenizer_bpe {
             int index = 0;
             size_t offset = 0;
 
-            if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
+            if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
                 symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
                 offset = word.size();
             }
@@ -13355,10 +13396,9 @@ struct llm_tokenizer_bpe {
                     for (auto j = str.begin(); j != str.end(); ++j) {
                         std::string byte_str(1, *j);
                         auto token_multibyte = vocab.token_to_id.find(byte_str);
-                        if (token_multibyte == vocab.token_to_id.end()) {
-                            throw std::runtime_error("ERROR: byte not found in vocab");
+                        if (token_multibyte != vocab.token_to_id.end()) {
+                            output.push_back(token_multibyte->second);
                         }
-                        output.push_back((*token_multibyte).second);
                     }
                 } else {
                     output.push_back((*token).second);
@@ -13397,6 +13437,8 @@ private:
 
     const llama_vocab & vocab;
 
+    std::vector<std::string> regex_exprs;
+
     std::vector<llm_symbol> symbols;
     std::vector<llm_symbol> symbols_final;
 
@@ -13677,7 +13719,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 
                 bool is_prev_special = false;
 
-                if (add_special && vocab.special_add_bos != 0) {
+                if (add_special && vocab.tokenizer_add_bos) {
                     GGML_ASSERT(vocab.special_bos_id != -1);
                     output.push_back(vocab.special_bos_id);
                     is_prev_special = true;
@@ -13687,7 +13729,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
-                        if (vocab.add_space_prefix) {
+                        if (vocab.tokenizer_add_space_prefix) {
                             if (!output.size() || is_prev_special) {  // prefix with space if first token
                                 raw_text = " " + raw_text;
                             }
@@ -13705,23 +13747,24 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     }
                 }
 
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
+                if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
                     LLAMA_LOG_WARN(
                         "%s: Added a BOS token to the prompt as specified by the model but the prompt "
                         "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
                         "Are you sure this is what you want?\n", __FUNCTION__);
                 }
 
-                if (add_special && vocab.special_add_eos == 1) {
+                if (add_special && vocab.tokenizer_add_eos) {
                     GGML_ASSERT(vocab.special_eos_id != -1);
                     output.push_back(vocab.special_eos_id);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_BPE:
             {
-                if (add_special && vocab.special_add_bos != 0) {
-                    GGML_ASSERT(vocab.special_bos_id != -1);
-                    output.push_back(vocab.special_bos_id);
+                llm_tokenizer_bpe tokenizer(vocab);
+
+                if (add_special) {
+                    tokenizer.append_bos(output);
                 }
 
                 for (const auto & fragment : fragment_buffer) {
@@ -13731,23 +13774,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_bpe tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
-                        output.push_back(fragment.token);
+                        tokenizer.append(fragment.token, output);
                     }
                 }
 
-                if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
-                    LLAMA_LOG_WARN(
-                        "%s: Added a BOS token to the prompt as specified by the model but the prompt "
-                        "also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
-                        "Are you sure this is what you want?\n", __FUNCTION__);
-                }
-
-                if (add_special && vocab.special_add_eos == 1) {
-                    GGML_ASSERT(vocab.special_add_eos != -1);
-                    output.push_back(vocab.special_eos_id);
+                if (add_special) {
+                    tokenizer.append_eos(output);
+                    tokenizer.check_double_bos_eos(output);
                 }
             } break;
         case LLAMA_VOCAB_TYPE_WPM:
@@ -18320,11 +18355,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
 }
 
 int32_t llama_add_bos_token(const struct llama_model * model) {
-    return model->vocab.special_add_bos;
+    return model->vocab.tokenizer_add_bos;
 }
 
 int32_t llama_add_eos_token(const struct llama_model * model) {
-    return model->vocab.special_add_eos;
+    return model->vocab.tokenizer_add_eos;
 }
 
 llama_token llama_token_prefix(const struct llama_model * model) {
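
A note on the `tokenizer_ignore_merges` flag introduced above (enabled for llama-3): when set, a pre-token that already exists verbatim in the vocabulary is emitted as a single token, skipping the BPE merge loop; this is what the 'Cửa Việt' edge case in the test file exercises. A minimal sketch of the shortcut, with a toy vocabulary (illustrative only, not the repo's data structures):

    def bpe_tokenize_word(word: str, token_to_id: dict[str, int], ignore_merges: bool) -> list[int]:
        # llama-3 style shortcut: a whole-word vocab hit bypasses BPE merges
        if ignore_merges and word in token_to_id:
            return [token_to_id[word]]
        raise NotImplementedError("merge loop elided in this sketch")

    toy_vocab = {"Việt": 101, "Vi": 5, "ệt": 7}  # made-up token ids
    assert bpe_tokenize_word("Việt", toy_vocab, ignore_merges=True) == [101]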

scripts/gen-unicode-data.py (+120 −60)

@@ -1,83 +1,143 @@
-import regex
-import ctypes
+import array
 import unicodedata
-
-
-class CoodepointFlags (ctypes.Structure):
-    _fields_ = [  # see definition in unicode.h
-        ("is_undefined",   ctypes.c_uint16, 1),
-        ("is_number",      ctypes.c_uint16, 1),  # regex: \p{N}
-        ("is_letter",      ctypes.c_uint16, 1),  # regex: \p{L}
-        ("is_separator",   ctypes.c_uint16, 1),  # regex: \p{Z}
-        ("is_accent_mark", ctypes.c_uint16, 1),  # regex: \p{M}
-        ("is_punctuation", ctypes.c_uint16, 1),  # regex: \p{P}
-        ("is_symbol",      ctypes.c_uint16, 1),  # regex: \p{S}
-        ("is_control",     ctypes.c_uint16, 1),  # regex: \p{C}
-    ]
-
-
-assert (ctypes.sizeof(CoodepointFlags) == 2)
+import requests
 
 
 MAX_CODEPOINTS = 0x110000
 
-regex_number      = regex.compile(r'\p{N}')
-regex_letter      = regex.compile(r'\p{L}')
-regex_separator   = regex.compile(r'\p{Z}')
-regex_accent_mark = regex.compile(r'\p{M}')
-regex_punctuation = regex.compile(r'\p{P}')
-regex_symbol      = regex.compile(r'\p{S}')
-regex_control     = regex.compile(r'\p{C}')
-regex_whitespace  = regex.compile(r'\s')
-
-codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
+UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
+
+
+# see https://www.unicode.org/L2/L1999/UnicodeData.html
+def unicode_data_iter():
+    res = requests.get(UNICODE_DATA_URL)
+    res.raise_for_status()
+    data = res.content.decode()
+
+    prev = []
+
+    for line in data.splitlines():
+        # e.g.: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
+        line = line.split(";")
+
+        cpt = int(line[0], base=16)
+        assert cpt < MAX_CODEPOINTS
+
+        cpt_lower = int(line[-2] or "0", base=16)
+        assert cpt_lower < MAX_CODEPOINTS
+
+        cpt_upper = int(line[-3] or "0", base=16)
+        assert cpt_upper < MAX_CODEPOINTS
+
+        categ = line[2].strip()
+        assert len(categ) == 2
+
+        bidir = line[4].strip()
+        assert len(bidir) > 0  # bidi class is never empty in UnicodeData.txt
+
+        name = line[1]
+        if name.endswith(", First>"):
+            prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
+            continue
+        if name.endswith(", Last>"):
+            assert prev[1:] == (0, 0, categ, bidir)
+            for c in range(prev[0], cpt):
+                yield (c, cpt_lower, cpt_upper, categ, bidir)
+
+        yield (cpt, cpt_lower, cpt_upper, categ, bidir)
+
+
+# see definition in unicode.h
+CODEPOINT_FLAG_UNDEFINED   = 0x0001  #
+CODEPOINT_FLAG_NUMBER      = 0x0002  # \p{N}
+CODEPOINT_FLAG_LETTER      = 0x0004  # \p{L}
+CODEPOINT_FLAG_SEPARATOR   = 0x0008  # \p{Z}
+CODEPOINT_FLAG_MARK        = 0x0010  # \p{M}
+CODEPOINT_FLAG_PUNCTUATION = 0x0020  # \p{P}
+CODEPOINT_FLAG_SYMBOL      = 0x0040  # \p{S}
+CODEPOINT_FLAG_CONTROL     = 0x0080  # \p{C}
+
+UNICODE_CATEGORY_TO_FLAG = {
+    "Cn": CODEPOINT_FLAG_UNDEFINED,    # Undefined
+    "Cc": CODEPOINT_FLAG_CONTROL,      # Control
+    "Cf": CODEPOINT_FLAG_CONTROL,      # Format
+    "Co": CODEPOINT_FLAG_CONTROL,      # Private Use
+    "Cs": CODEPOINT_FLAG_CONTROL,      # Surrrogate
+    "Ll": CODEPOINT_FLAG_LETTER,       # Lowercase Letter
+    "Lm": CODEPOINT_FLAG_LETTER,       # Modifier Letter
+    "Lo": CODEPOINT_FLAG_LETTER,       # Other Letter
+    "Lt": CODEPOINT_FLAG_LETTER,       # Titlecase Letter
+    "Lu": CODEPOINT_FLAG_LETTER,       # Uppercase Letter
+    "L&": CODEPOINT_FLAG_LETTER,       # Cased Letter
+    "Mc": CODEPOINT_FLAG_MARK,         # Spacing Mark
+    "Me": CODEPOINT_FLAG_MARK,         # Enclosing Mark
+    "Mn": CODEPOINT_FLAG_MARK,         # Nonspacing Mark
+    "Nd": CODEPOINT_FLAG_NUMBER,       # Decimal Number
+    "Nl": CODEPOINT_FLAG_NUMBER,       # Letter Number
+    "No": CODEPOINT_FLAG_NUMBER,       # Other Number
+    "Pc": CODEPOINT_FLAG_PUNCTUATION,  # Connector Punctuation
+    "Pd": CODEPOINT_FLAG_PUNCTUATION,  # Dash Punctuation
+    "Pe": CODEPOINT_FLAG_PUNCTUATION,  # Close Punctuation
+    "Pf": CODEPOINT_FLAG_PUNCTUATION,  # Final Punctuation
+    "Pi": CODEPOINT_FLAG_PUNCTUATION,  # Initial Punctuation
+    "Po": CODEPOINT_FLAG_PUNCTUATION,  # Other Punctuation
+    "Ps": CODEPOINT_FLAG_PUNCTUATION,  # Open Punctuation
+    "Sc": CODEPOINT_FLAG_SYMBOL,       # Currency Symbol
+    "Sk": CODEPOINT_FLAG_SYMBOL,       # Modifier Symbol
+    "Sm": CODEPOINT_FLAG_SYMBOL,       # Math Symbol
+    "So": CODEPOINT_FLAG_SYMBOL,       # Other Symbol
+    "Zl": CODEPOINT_FLAG_SEPARATOR,    # Line Separator
+    "Zp": CODEPOINT_FLAG_SEPARATOR,    # Paragraph Separator
+    "Zs": CODEPOINT_FLAG_SEPARATOR,    # Space Separator
+}
+
+
+codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
 table_whitespace = []
 table_lowercase = []
 table_uppercase = []
 table_nfd = []
 
-for codepoint in range(MAX_CODEPOINTS):
+for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
     # convert codepoint to unicode character
-    char = chr(codepoint)
-
-    # regex categories
-    flags = codepoint_flags[codepoint]
-    flags.is_number      = bool(regex_number.match(char))
-    flags.is_letter      = bool(regex_letter.match(char))
-    flags.is_separator   = bool(regex_separator.match(char))
-    flags.is_accent_mark = bool(regex_accent_mark.match(char))
-    flags.is_punctuation = bool(regex_punctuation.match(char))
-    flags.is_symbol      = bool(regex_symbol.match(char))
-    flags.is_control     = bool(regex_control.match(char))
-    flags.is_undefined   = bytes(flags)[0] == 0
-    assert (not flags.is_undefined)
-
-    # whitespaces
-    if bool(regex_whitespace.match(char)):
-        table_whitespace.append(codepoint)
+    char = chr(cpt)
+
+    # codepoint category flags
+    codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
 
     # lowercase conversion
-    lower = ord(char.lower()[0])
-    if codepoint != lower:
-        table_lowercase.append((codepoint, lower))
+    if cpt_lower:
+        table_lowercase.append((cpt, cpt_lower))
 
     # uppercase conversion
-    upper = ord(char.upper()[0])
-    if codepoint != upper:
-        table_uppercase.append((codepoint, upper))
+    if cpt_upper:
+        table_uppercase.append((cpt, cpt_upper))
 
     # NFD normalization
     norm = ord(unicodedata.normalize('NFD', char)[0])
-    if codepoint != norm:
-        table_nfd.append((codepoint, norm))
+    if cpt != norm:
+        table_nfd.append((cpt, norm))
+
+
+# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
+table_whitespace.extend(range(0x0009, 0x000D + 1))
+table_whitespace.extend(range(0x2000, 0x200A + 1))
+table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
+
+
+# sort by codepoint
+table_whitespace.sort()
+table_lowercase.sort()
+table_uppercase.sort()
+table_nfd.sort()
 
 
 # group ranges with same flags
 ranges_flags = [(0, codepoint_flags[0])]  # start, flags
 for codepoint, flags in enumerate(codepoint_flags):
-    if bytes(flags) != bytes(ranges_flags[-1][1]):
+    if flags != ranges_flags[-1][1]:
         ranges_flags.append((codepoint, flags))
-ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
+ranges_flags.append((MAX_CODEPOINTS, 0x0000))
 
 
 # group ranges with same nfd
@@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
     ranges_nfd[-1] = (start, codepoint, norm)
 
 
-# Generate 'unicode-data.cpp'
-
+# Generate 'unicode-data.cpp':
+#   python ./scripts/gen-unicode-data.py > unicode-data.cpp
 
 def out(line=""):
     print(line, end='\n')  # noqa
@@ -110,12 +170,12 @@ out("""\
 
 out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
 for codepoint, flags in ranges_flags:
-    flags = int.from_bytes(bytes(flags), "little")
     out("{0x%06X, 0x%04X}," % (codepoint, flags))
 out("};\n")
 
 out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
-out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
+for codepoint in table_whitespace:
+    out("0x%06X," % codepoint)
 out("};\n")
 
 out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")

tests/test-tokenizer-random.py (+121 −49)

@@ -11,13 +11,15 @@ import logging
 import argparse
 import subprocess
 import random
+import unicodedata
 
 from typing import Callable, Iterator
 
 import cffi
 from transformers import AutoTokenizer
 
-logger = logging.getLogger("test-tokenizer-random-bpe")
+
+logger = logging.getLogger("test-tokenizer-random")
 
 
 class LibLlama:
@@ -155,9 +157,14 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         'Cửa Việt',   # llama-3, ignore_merges = true
         '<s>a',       # Phi-3 fail
         '<unk><|endoftext|><s>',  # Phi-3 fail
-        'a\na',       # TODO: Bert fail
-        'a </s> b',   # rstrip phi-3
-        'a <mask> b', # lstrip jina-v2
+        'a\na',            # bert fail
+        '"`',              # falcon
+        ' \u2e4e',         # falcon
+        'a\xa0\xa0\x00b',  # jina-v2-es
+        'one <mask>',      # jina-v2-es  <mask> lstrip=true
+        'a </s> b',        # rstrip phi-3
+        'a <mask> b',      # lstrip jina-v2
+        '\xa0aC',          # deepseek
     ]
 
 
@@ -189,17 +196,23 @@ def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
     for m in range(iterations):
         rand.seed(m)
         words = rand.choices(all_tokens, k=500)
-        if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
+        if words and words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
             while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                 words.pop(0)
             if tokenizer.add_bos_token:  # drop all starting BOS
                 words.pop(0)
+        if words and words[-1] == tokenizer.eos_token:  # skip spam warning of double EOS
+            while len(words) > 1 and words[-2] == tokenizer.eos_token:  # leave one trailing EOS
+                words.pop(-1)
+            if tokenizer.add_eos_token:  # drop all trailing EOS
+                words.pop(-1)
         yield "".join(words)
 
 
 def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
+    NUM_WORDS = 400
     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
     CHARS = list(sorted(set("""
         ABCDEFGHIJKLMNOPQRSTUVWXYZ
@@ -213,12 +226,50 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
     for m in range(iterations):
         rand.seed(m)
         text = []
-        num_words = rand.randint(300, 400)
-        for i in range(num_words):
+        for _ in range(NUM_WORDS):
             k = rand.randint(1, 7)
             word = rand.choices(CHARS, k=k)
-            space = rand.choice(WHITESPACES)
-            text.append("".join(word) + space)
+            word.append(rand.choice(WHITESPACES))
+            text.append("".join(word))
+        yield "".join(text)
+
+
+def generator_unicodes() -> Iterator[str]:
+    """Iterate unicode characters"""
+
+    MAX_CODEPOINTS = 0x30000  # 0x110000
+
+    def _valid(cpt):
+        if cpt >= 0x30000:  # unassigned and supplementary
+            return False
+        if 0x00D800 <= cpt <= 0x00F8FF:  # surrogates and private use area
+            return False
+        if unicodedata.category(chr(cpt)) == "Cn":
+            return False
+        return True
+
+    characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
+
+    yield from characters
+
+
+def generator_random_unicodes(iterations=100) -> Iterator[str]:
+    """Brute force random text with unicode characters"""
+
+    NUM_WORDS = 200
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
+
+    characters = list(generator_unicodes())
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        for _ in range(NUM_WORDS):
+            k = rand.randint(1, 7)
+            word = rand.choices(characters, k=k)
+            word.append(rand.choice(WHITESPACES))
+            text.append("".join(word))
         yield "".join(text)
 
 
@@ -256,25 +307,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
         yield "".join(text)
 
 
-def generator_random_bytes(iterations=100) -> Iterator[str]:
-    """Brute force random bytes"""
-
-    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
-
-    rand = random.Random()
-    for m in range(iterations):
-        rand.seed(m)
-        text = []
-        num_words = rand.randint(300, 400)
-        for i in range(num_words):
-            k = rand.randint(1, 8)
-            word = [chr(r) for r in rand.randbytes(k) if r]
-            word.append(rand.choice(WHITESPACES))
-            text.append("".join(word))
-        yield "".join(text)
-
-
-def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
+def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
         for i, (a, b) in enumerate(zip(ids1, ids2)):
@@ -284,20 +317,34 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
             return -1
         return min(len(ids1), len(ids2))
 
-    t0 = time.perf_counter()
+    t_tokenizer1 = 0
+    t_tokenizer2 = 0
+    t_start = time.perf_counter()
+    num_errors = 0
+
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
+        # print(repr(text), hex(ord(text[0])), text.encode())
+        t0 = time.perf_counter()
         ids1 = func_tokenize1(text)
+        t1 = time.perf_counter()
         ids2 = func_tokenize2(text)
+        t2 = time.perf_counter()
+        t_tokenizer1 += t1 - t0
+        t_tokenizer2 += t2 - t1
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
-            logger.info(" TokenIDs: " + str(ids1))
-            logger.info(" Expected: " + str(ids2))
-            raise Exception()
-    t1 = time.perf_counter()
-    logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
+            logger.error(" TokenIDs: " + str(ids1))
+            logger.error(" Expected: " + str(ids2))
+            # raise Exception()
+            num_errors += 1
+            if num_errors > 10:
+                break
+
+    t_total = time.perf_counter() - t_start
+    logger.info("%s: end,  tok1: %.3f  tok2: %.3f  total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
 
 
 def main(argv: list[str] = None):
@@ -307,7 +354,8 @@ def main(argv: list[str] = None):
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
     args = parser.parse_args(argv)
 
-    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+    logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
+    logger.info(f"VOCABFILE: '{args.vocab_file}'")
 
     model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
@@ -321,18 +369,22 @@ def main(argv: list[str] = None):
     ids = func_tokenize2("a")
     assert 1 <= len(ids) <= 3
     add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
+    add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
     tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
+    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
 
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
-    # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
+
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
+    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
 
     model.free()
 
@@ -340,20 +392,40 @@ def main(argv: list[str] = None):
 if __name__ == "__main__":
     # main()
 
+    logging.basicConfig(
+        level    = logging.DEBUG,
+        format   = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
+        datefmt  = "%Y-%m-%d %H:%M:%S",
+        filename = logger.name + ".log",
+        filemode = "a"
+    )
+
     path_tokenizers   = "./models/tokenizers/"
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
 
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        "llama-spm",   # SPM
-        "phi-3",       # SPM
-        "jina-v2-en",  # WPM
-        "bert-bge",    # WPM
+        # "llama-spm",   # SPM
+        # "phi-3",       # SPM
+        # "bert-bge",    # WPM
+        # "jina-v2-en",  # WPM
+        "gpt-2",          # BPE
+        "llama-bpe",      # BPE
+        "falcon",         # BPE
+        "starcoder",      # BPE
+        "jina-v2-es",     # BPE
+        "jina-v2-de",     # BPE
+        "jina-v2-code",   # BPE
+        "smaug-bpe",      # BPE
+        "phi-2",          # BPE
+        "deepseek-coder", # BPE
+        "deepseek-llm",   # BPE
     ]
 
     for tokenizer in tokenizers:
-        print("\n" + "=" * 50 + "\n" + tokenizer + "\n")  # noqa
+        logger.info("=" * 50)
+        logger.info(f"TOKENIZER: '{tokenizer}'")
         vocab_file = path_vocab_format % tokenizer
         dir_tokenizer = path_tokenizers + "/" + tokenizer
         main([vocab_file, dir_tokenizer, "--verbose"])
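
With the `__main__` block above, the script loops over the listed BPE vocabs and logs to `test-tokenizer-random.log`. A single model can also be checked directly through `main()` (a sketch; it assumes the GGUF vocab and the HF tokenizer directory have already been fetched and converted):

    # one-off run, equivalent to a single iteration of the loop above
    main(["./models/ggml-vocab-llama-bpe.gguf",
          "./models/tokenizers/llama-bpe",
          "--verbose"])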

File diff suppressed because it is too large
unicode-data.cpp (+273 −267)


unicode.cpp (+21 −8)

@@ -226,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         assert(offset_end <= cpts.size());
         start = offset_end;
 
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
         auto _get_cpt = [&] (const size_t pos) -> uint32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@@ -309,7 +310,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -344,8 +345,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         assert(offset_end <= cpts.size());
         start = offset_end;
 
+        static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
         auto _get_cpt = [&] (const size_t pos) -> uint32_t {
-            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
         };
 
         auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
@@ -450,7 +452,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
 
             // regex: \s+(?!\S)
-            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
                 pos += num_whitespaces - 1;
                 _add_token(pos);
                 continue;
@@ -679,10 +681,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
+            const auto flags = unicode_cpt_flags(cpts[i]);
 
-            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
+            if (flags.is_whitespace) {
+                //NOTE: C++ std::regex \s does not match 0x85; Rust and Python regex do.
+                //text_collapsed[i] = (char) 0x85;  // <Next Line> as whitespace fallback
+                text_collapsed[i] = (char) 0x0B;    // <vertical tab> as whitespace fallback
+            } else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
             } else {
                 text_collapsed[i] = (char) 0xD0; // fallback
             }
@@ -766,9 +772,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
             } else {
                 // no unicode category used, we can use std::wregex directly
-                const std::wstring wtext       = unicode_wstring_from_utf8(text);
                 const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
 
+                // std::wregex \s does not match non-ASCII whitespaces; using 0x0B as fallback
+                std::wstring wtext(cpts.begin(), cpts.end());
+                for (size_t i = 0; i < wtext.size(); ++i) {
+                    if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
+                        wtext[i] = 0x0B;
+                    }
+                }
+
                 //printf("text: %s\n", text.c_str());
                 //printf("regex_expr: %s\n", regex_expr.c_str());
                 bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
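
The whitespace fallback in the hunks above hinges on 0x0B (vertical tab) being matched by `\s` in every engine involved, while U+0085 (NEL) is whitespace for Rust and Python but not for C++ `std::regex`. The Python half of that claim is easy to verify (a standalone check, not part of the commit):

    import re

    assert re.match(r"\s", "\u0085")  # NEL is whitespace in Python's re
    assert re.match(r"\s", "\x0b")    # vertical tab: whitespace in std::regex too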

Some files were not shown because too many files changed in this diff.