
llama3 custom regex split (#6965)

* merged the changes from deepseek models to main branch

* Moved regex patterns to unicode.cpp and updated unicode.h

* Moved header files

* Resolved issues

* added and refactored unicode_regex_split and related functions

* Updated/merged the deepseek coder pr

* Refactored code

* Adding unicode regex mappings

* Adding unicode regex function

* Added needed functionality, testing remains

* Fixed issues

* Fixed issue with gpt2 regex custom preprocessor

* unicode : fix? unicode_wstring_to_utf8

* lint : fix whitespaces

* tests : add tokenizer tests for numbers

* unicode : remove redundant headers

* tests : remove and rename tokenizer test scripts

* tests : add sample usage

* gguf-py : reader prints warnings on duplicate keys

* llama : towards llama3 tokenization support (wip)

* unicode : shot in the dark to fix tests on Windows

* unicode : first try custom implementations

* convert : add "tokenizer.ggml.pre" GGUF KV (wip)

* llama : use new pre-tokenizer type

* convert : fix pre-tokenizer type writing

* lint : fix

* make : add test-tokenizer-0-llama-v3

* wip

* models : add llama v3 vocab file

* llama : adapt punctuation regex + add llama 3 regex

* minor

* unicode : set bomb

* unicode : set bomb

* unicode : always use std::wregex

* unicode : support \p{N}, \p{L} and \p{P} natively

* unicode : try fix windows

* unicode : category support via std::regex

* unicode : clean-up

* unicode : simplify

* llama3 custom regex split

* convert : add convert-hf-to-gguf-update.py

ggml-ci

* lint : update

* convert : add falcon

ggml-ci

* unicode : normalize signatures

* lint : fix

* lint : fix

* convert : remove unused functions

* convert : add comments

* convert : exercise contractions

ggml-ci

* Using char32_t for codepoints

* lint : fix

* already exists unicode_tolower()

* Typing

* Restore BOM

* cmake : refactor test targets

* tests : refactor vocab tests

ggml-ci

* tests : add more vocabs and tests

ggml-ci

* unicode : cleanup

* scripts : ignore new update script in check-requirements.sh

* Fix merge

* models : add phi-3, mpt, gpt-2, starcoder

* tests : disable obsolete

ggml-ci

* tests : use faster bpe test

ggml-ci

* llama : more prominent warning for old BPE models

* tests : disable test-tokenizer-1-bpe due to slowness

ggml-ci

* Move unused variable value

* GPT2 custom regex split

* Add alternative regex for custom split llama3

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Style

* Add bruteforce random tests for token encoding

* wip: fixing unicode codepoint ranges

* Fix merge

* Unicode tables: separator, lowercase, uppercase and whitespace

* llama3 custom regex split: fix \s

* Restore BOM

* Style

* wip: generate NDF table

* Ignore special tokens for testing

* Clean gen-unicode-data.py

* Refactor random tokenizer test

* lint : fix

* tests : add fail test for llama-bpe

---------

Co-authored-by: Jaggzh <jaggz.h@gmail.com>
Co-authored-by: Kazim Abrar Mahi <kazimabrarmahi135@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: jaime-m-p <>
jaime-m-p 1 year ago
Parent
Commit
43248e5594
8 changed files with 1463 additions and 302 deletions
  1. convert-hf-to-gguf-update.py (+1 -0)
  2. llama.cpp (+1 -1)
  3. scripts/gen-unicode-data.py (+38 -40)
  4. tests/test-tokenizer-random.py (+295 -0)
  5. unicode-data.cpp (+875 -143)
  6. unicode-data.h (+1 -0)
  7. unicode.cpp (+249 -117)
  8. unicode.h (+3 -1)

+ 1 - 0
convert-hf-to-gguf-update.py

@@ -261,6 +261,7 @@ tests = [
     "3333333",
     "33333333",
     "333333333",
+    # "Cửa Việt", # llama-bpe fails on this
     chktxt,
 ]
 

+ 1 - 1
llama.cpp

@@ -12488,7 +12488,7 @@ struct llm_tokenizer_wpm {
                 continue;
             }
             code = unicode_tolower(code);
-            if (type == CODEPOINT_TYPE_WHITESPACE) {
+            if (type == CODEPOINT_TYPE_SEPARATOR) {
                 code = ' ';
             }
             std::string s = unicode_cpt_to_utf8(code);
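
The WPM tokenizer now maps codepoints of type SEPARATOR (Unicode category Z, i.e. \p{Z}) to a plain space. The old CODEPOINT_TYPE_WHITESPACE name was misleading because \p{Z} and the regex class \s are different sets, which is why this PR keeps them in separate tables (unicode_ranges_separator vs unicode_ranges_whitespace). A small demonstration with the third-party regex module, the same one scripts/gen-unicode-data.py uses; a sketch for illustration, not code from the change:

# Sketch: \p{Z} (separator category) vs \s (whitespace class) are not the same set.
import regex

for ch in [" ", "\t", "\n", "\u00A0", "\u2028"]:
    is_separator  = regex.match(r"\p{Z}", ch) is not None  # Unicode category Z*
    is_whitespace = regex.match(r"\s", ch) is not None      # regex whitespace class
    print(repr(ch), "separator:", is_separator, "whitespace:", is_whitespace)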

+ 38 - 40
scripts/gen-unicode-data.py

@@ -1,31 +1,14 @@
 import regex
 
 
-def cpt_to_utf8_str(cpt):
-    if cpt <= 0xFF:
-        return bytes([cpt, 0, 0, 0])
-    elif cpt <= 0xFFFF:
-        return bytes([cpt & 0xFF, cpt >> 8, 0, 0])
-    elif cpt <= 0xFFFFFF:
-        return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, 0])
-    else:
-        return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, cpt >> 24])
-
-
-def is_match(codepoint, regex_expr):
-    try:
-        res = regex.match(regex_expr, cpt_to_utf8_str(codepoint).decode('utf-32'))
-        return res is not None
-    except Exception:
-        return False
-
-
 def get_matches(regex_expr):
+    regex_expr_compiled = regex.compile(regex_expr)
     unicode_ranges = []
     current_range = None
 
     for codepoint in range(0x110000):
-        if is_match(codepoint, regex_expr):
+        char = chr(codepoint)
+        if regex_expr_compiled.match(char):
             if current_range is None:
                 current_range = [codepoint, codepoint]
             else:
@@ -40,27 +23,42 @@ def get_matches(regex_expr):
     return unicode_ranges
 
 
-def print_cat(cat, ranges):
-    print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
-    cnt = 0
-    for start, end in ranges:
-        if cnt % 4 != 0:
-            print(" ", end="") # noqa: NP100
-        print("{{0x{:08X}, 0x{:08X}}},".format(start, end), end="") # noqa: NP100
-        if cnt % 4 == 3:
-            print("") # noqa: NP100
-        cnt += 1
-
-    if cnt % 4 != 0:
-        print("") # noqa: NP100
+def print_cat(mode, cat, ranges):
+    if mode == "range":
+        print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
+    if mode == "map":
+        print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
+    for i, values in enumerate(ranges):
+        end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
+        values = ["0x%08X" % value for value in values]
+        print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
     print("};") # noqa: NP100
     print("") # noqa: NP100
 
 
-print_cat("number",      get_matches(r'\p{N}'))
-print_cat("letter",      get_matches(r'\p{L}'))
-print_cat("whitespace",  get_matches(r'\p{Z}'))
-print_cat("accent_mark", get_matches(r'\p{M}'))
-print_cat("punctuation", get_matches(r'\p{P}'))
-print_cat("symbol",      get_matches(r'\p{S}'))
-print_cat("control",     get_matches(r'\p{C}'))
+print_cat("range", "number",      get_matches(r'\p{N}'))
+print_cat("range", "letter",      get_matches(r'\p{L}'))
+print_cat("range", "separator",   get_matches(r'\p{Z}'))
+print_cat("range", "accent_mark", get_matches(r'\p{M}'))
+print_cat("range", "punctuation", get_matches(r'\p{P}'))
+print_cat("range", "symbol",      get_matches(r'\p{S}'))
+print_cat("range", "control",     get_matches(r'\p{C}'))
+
+print_cat("range", "whitespace",  get_matches(r'\s'))
+
+
+map_lowercase = []
+map_uppercase = []
+for codepoint in range(0x110000):
+    char = chr(codepoint)
+    lower = ord(char.lower()[0])
+    upper = ord(char.upper()[0])
+    if codepoint != lower:
+        map_lowercase.append((codepoint, lower))
+    if codepoint != upper:
+        map_uppercase.append((codepoint, upper))
+print_cat("map", "lowercase", map_lowercase)
+print_cat("map", "uppercase", map_uppercase)
+
+
+# TODO: generate unicode_map_nfd
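
The script emits closed [first, last] ranges per category plus lowercase/uppercase maps. A hedged sketch of an equivalent lookup in Python (binary search here, whereas unicode.cpp expands the unicode_ranges_* tables into a codepoint-to-type map); the ranges below are a tiny hand-made subset, not generated output:

# Sketch: classify a codepoint against sorted, closed [first, last] ranges,
# mirroring how unicode.cpp consumes the unicode_ranges_* tables.
import bisect

ranges_number = [(0x0030, 0x0039), (0x0660, 0x0669)]  # ASCII digits, Arabic-Indic digits

def in_ranges(cpt: int, ranges: list[tuple[int, int]]) -> bool:
    i = bisect.bisect_right(ranges, (cpt, 0x10FFFF)) - 1
    return i >= 0 and ranges[i][0] <= cpt <= ranges[i][1]

print(in_ranges(ord("7"), ranges_number))  # True
print(in_ranges(ord("a"), ranges_number))  # False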

+ 295 - 0
tests/test-tokenizer-random.py

@@ -0,0 +1,295 @@
+# Test libllama tokenizer == AutoTokenizer.
+# Brute force random tokens/text generation.
+#
+# Sample usage:
+#
+#   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
+#
+
+import time
+import logging
+import argparse
+import subprocess
+import random
+
+from typing import Iterator
+
+import cffi
+from transformers import AutoTokenizer, PreTrainedTokenizerBase
+
+logger = logging.getLogger("test-tokenizer-random-bpe")
+
+
+class LibLlama:
+
+    DEFAULT_PATH_LLAMA_H = "./llama.h"
+    DEFAULT_PATH_LIBLLAMA = "./build/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON
+
+    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
+        path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
+        path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
+        (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
+        self.lib.llama_backend_init()
+
+    def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
+        cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
+        res = subprocess.run(cmd, stdout=subprocess.PIPE)
+        assert (res.returncode == 0)
+        source = res.stdout.decode()
+        ffi = cffi.FFI()
+        if True:  # workarounds for pycparser
+            source = "typedef struct { } __builtin_va_list;" + "\n" + source
+            source = source.replace("sizeof (int)",    str(ffi.sizeof("int")))
+            source = source.replace("sizeof (void *)", str(ffi.sizeof("void*")))
+            source = source.replace("sizeof (size_t)", str(ffi.sizeof("size_t")))
+            source = source.replace("sizeof(int32_t)", str(ffi.sizeof("int32_t")))
+        ffi.cdef(source, override=True)
+        lib = ffi.dlopen(path_libllama)
+        return (ffi, lib)
+
+    def model_default_params(self, **kwargs):
+        mparams = self.lib.llama_model_default_params()
+        for k, v in kwargs.items():
+            setattr(mparams, k, v)
+        return mparams
+
+    def context_default_params(self, **kwargs):
+        cparams = self.lib.llama_context_default_params()
+        for k, v in kwargs.items():
+            setattr(cparams, k, v)
+        return cparams
+
+
+class LibLlamaModel:
+
+    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
+        self.lib = libllama.lib
+        self.ffi = libllama.ffi
+        if isinstance(mparams, dict):
+            mparams = libllama.model_default_params(**mparams)
+        self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
+        if not self.model:
+            raise RuntimeError("error: failed to load model '%s'" % path_model)
+        if isinstance(cparams, dict):
+            cparams = libllama.context_default_params(**cparams)
+        self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
+        if not self.ctx:
+            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
+        n_tokens_max = self.lib.llama_n_ctx(self.ctx)
+        self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)
+
+    def free(self):
+        if self.ctx:
+            self.lib.llama_free(self.ctx)
+        if self.model:
+            self.lib.llama_free_model(self.model)
+        self.ctx = None
+        self.model = None
+        self.lib = None
+
+    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
+        n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
+        text = text.encode("utf-8")
+        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
+        if num < 0:
+            return []
+        return list(self.token_ids[0:num])
+
+
+def generator_custom_text() -> Iterator[str]:
+    """General tests"""
+    yield from [
+        "",
+        " ",
+        "  ",
+        "   ",
+        "\t",
+        "\n",
+        "\n\n",
+        "\n\n\n",
+        "\t\n",
+        "Hello world",
+        " Hello world",
+        "Hello World",
+        " Hello World",
+        " Hello World!",
+        "Hello, world!",
+        " Hello, world!",
+        " this is 🦙.cpp",
+        "w048 7tuijk dsdfhu",
+        "нещо на Български",
+        "កាន់តែពិសេសអាចខលចេញ",
+        "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
+        "Hello",
+        " Hello",
+        "  Hello",
+        "   Hello",
+        "    Hello",
+        "    Hello\n    Hello",
+        " (",
+        "\n =",
+        "' era",
+        "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+        "3",
+        "33",
+        "333",
+        "3333",
+        "33333",
+        "333333",
+        "3333333",
+        "33333333",
+        "333333333",
+    ]
+
+
+def generator_custom_text_edge_cases() -> Iterator[str]:
+    """Edge cases found while debugging"""
+    yield from [
+        '\x1f-a',   # unicode_ranges_control, {0x00001C, 0x00001F}
+        '¼-a',      # unicode_ranges_digit, 0x00BC
+        '½-a',      # unicode_ranges_digit, 0x00BD
+        '¾-a',      # unicode_ranges_digit, 0x00BE
+        'a 〇b',    # unicode_ranges_digit, 0x3007
+        'Ⅵ-a',     # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
+        '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
+        '<s>a'      # TODO: Phi-3 fail
+    ]
+
+
+def generator_random_chars(iterations = 100) -> Iterator[str]:
+    """Brute force random text with simple characters"""
+
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
+    CHARS = list(set("""
+        ABCDEFGHIJKLMNOPQRSTUVWXYZ
+        abcdefghijklmnopqrstuvwxyz
+        ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
+        áéíóúàèìòùâêîôûäëïöü
+        .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
+    """))
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        num_words = rand.randint(300, 400)
+        for i in range(num_words):
+            k = rand.randint(1, 7)
+            word = rand.choices(CHARS, k=k)
+            space = rand.choice(WHITESPACES)
+            text.append("".join(word) + space)
+        yield "".join(text)
+
+
+def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+    """Brute force random text with vocab characters"""
+
+    vocab_ids = list(tokenizer.vocab.values())
+    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
+    vocab_chars = list(set(vocab_text))
+    del vocab_ids, vocab_text
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = rand.choices(vocab_chars, k=1024)
+        yield "".join(text)
+
+
+def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+    """Brute force random text from vocab tokens"""
+
+    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
+    vocab_ids = list(tokenizer.vocab.values())
+    vocab_ids = list(sorted(vocab_ids + vocab_ids))
+    for i in range(1, len(vocab_ids), 2):
+        vocab_ids[i] = space_id
+    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
+    vocab_tokens = vocab_tokens.split(" ")
+    del vocab_ids
+
+    yield from vocab_tokens
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        num_words = rand.randint(300, 400)
+        for i in range(num_words):
+            k = rand.randint(1, 3)
+            tokens = rand.choices(vocab_tokens, k=k)
+            tokens = [t.strip(" \n\r\t") for t in tokens]
+            sep = rand.choice("     \n\r\t")
+            text.append("".join(tokens) + sep)
+        yield "".join(text)
+
+
+def generator_random_bytes(iterations = 100) -> Iterator[str]:
+    """Brute force random bytes"""
+
+    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
+
+    rand = random.Random()
+    for m in range(iterations):
+        rand.seed(m)
+        text = []
+        num_words = rand.randint(300, 400)
+        for i in range(num_words):
+            k = rand.randint(1, 8)
+            word = [chr(r) for r in rand.randbytes(k) if r]
+            word.append(rand.choice(WHITESPACES))
+            text.append("".join(word))
+        yield "".join(text)
+
+
+def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
+
+    def find_first_mismatch(ids1: list[int], ids2: list[int]):
+        for i, (a,b) in enumerate(zip(ids1, ids2)):
+            if a != b:
+                return i
+        if len(ids1) == len(ids2):
+            return -1
+        return min(len(ids1), len(ids2))
+
+    t0 = time.perf_counter()
+    logger.info("%s: %s" % (generator.__name__, "ini"))
+    for text in generator:
+        ids1 = model.tokenize(text, add_special=False, parse_special=False)
+        ids2 = tokenizer.encode(text, add_special_tokens=False)
+        if ids1 != ids2:
+            i = find_first_mismatch(ids1, ids2)
+            ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
+            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
+            assert (text2 in text)
+            logger.info(" Text:     " + repr(text2))
+            logger.info(" TokenIDs: " + str(ids1))
+            logger.info(" Expected: " + str(ids2))
+            raise Exception()
+    t1 = time.perf_counter()
+    logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
+
+
+if __name__ == "__main__":
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
+    parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
+    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
+    args = parser.parse_args()
+
+    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
+
+    tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
+
+    test_compare_tokenizer(model, tokenizer, generator_custom_text())
+    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
+    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
+    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
+    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
+    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
+
+    model.free()
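
Since test_compare_tokenizer accepts any iterator of strings, known-bad inputs can be replayed through the same harness as the random generators. A hypothetical extra generator, not part of this file, collecting the cases flagged elsewhere in this PR:

# Hypothetical helper, not part of tests/test-tokenizer-random.py:
# replay inputs that are known or suspected to mismatch, using the same harness.
def generator_known_failures() -> Iterator[str]:
    """Inputs flagged as failing elsewhere in this PR"""
    yield from [
        "Cửa Việt",  # commented out in convert-hf-to-gguf-update.py: llama-bpe fails
        "<s>a",      # TODO: Phi-3 fail (see generator_custom_text_edge_cases)
    ]

# test_compare_tokenizer(model, tokenizer, generator_known_failures())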

The diff for this file is not shown because it is too large
+ 875 - 143
unicode-data.cpp


+ 1 - 0
unicode-data.h

@@ -7,6 +7,7 @@
 
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
+extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
 extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;

+ 249 - 117
unicode.cpp

@@ -9,6 +9,7 @@
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 #include <locale>
@@ -111,27 +112,27 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
 static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
     std::unordered_map<uint32_t, int> cpt_types;
     for (auto p : unicode_ranges_number) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_NUMBER;
         }
     }
     for (auto p : unicode_ranges_letter) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_LETTER;
         }
     }
-    for (auto p : unicode_ranges_whitespace) {
-        for (auto i = p.first; i <= p.second; ++ i) {
-            cpt_types[i] = CODEPOINT_TYPE_WHITESPACE;
+    for (auto p : unicode_ranges_separator) {
+        for (auto i = p.first; i <= p.second; ++i) {
+            cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
         }
     }
     for (auto p : unicode_ranges_accent_mark) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
         }
     }
     for (auto p : unicode_ranges_punctuation) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
         }
     }
@@ -141,7 +142,7 @@ static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
         }
     }
     for (auto p : unicode_ranges_control) {
-        for (auto i = p.first; i <= p.second; ++ i) {
+        for (auto i = p.first; i <= p.second; ++i) {
             cpt_types[i] = CODEPOINT_TYPE_CONTROL;
         }
     }
@@ -224,138 +225,256 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
     std::vector<size_t> bpe_offsets; // store the offset of each word
     bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
-    size_t start = 0;
-
     const auto cpts = unicode_cpts_from_utf8(text);
 
+    size_t start = 0;
     for (auto offset : offsets) {
-        std::string token;
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        auto _get_cpt = [&] (const size_t pos) -> char32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        };
+
+        auto _get_cpt_type = [&] (const size_t pos) -> int {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const char32_t cpt = _get_cpt(pos);
+            const int cpt_type = _get_cpt_type(pos);
+
+            // regex: 's|'t|'re|'ve|'m|'ll|'d
+            if (cpt == '\'' && pos+1 < offset_end) {
+                char32_t cpt_next = _get_cpt(pos+1);
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
+                }
+                if (pos+2 < offset_end) {
+                    char32_t cpt_next_next = _get_cpt(pos+2);
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
+                    }
+                }
+            }
 
-        bool collecting_numeric = false;
-        bool collecting_letter = false;
-        bool collecting_special = false;
-        bool collecting_whitespace_lookahead = false;
-        bool collecting = false;
+            char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
+            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
+            // regex: <space>?\p{L}+
+            if (cpt2_type == CODEPOINT_TYPE_LETTER) {
+                pos += (cpt == ' ');
+                while (cpt2_type == CODEPOINT_TYPE_LETTER) {
+                    cpt2_type = _get_cpt_type(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: <space>?\p{N}+
+            if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
+                pos += (cpt == ' ');
+                while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
+                    cpt2_type = _get_cpt_type(++pos);
+                }
+                _add_token(pos);
+                continue;
+            }
+            // regex: <space>?[^\s\p{L}\p{N}]+
+            if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                pos += (cpt == ' ');
+                while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                    cpt2_type = _get_cpt_type(++pos);
+                    cpt2 = _get_cpt(pos);
+                }
+                _add_token(pos);
+                continue;
+            }
 
-        std::vector<std::string> text_utf;
-        text_utf.reserve(offset);
+            size_t num_whitespaces = 0;
+            while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+                num_whitespaces++;
+            }
 
-        for (size_t i = start; i < start + offset; ++i) {
-            text_utf.emplace_back(unicode_cpt_to_utf8(cpts[i]));
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
+            }
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
+            }
+
+            // no matches
+            _add_token(++pos);
         }
+    }
+
+    return bpe_offsets;
+}
 
-        for (int i = 0; i < (int)text_utf.size(); i++) {
-            const std::string & utf_char = text_utf[i];
-            bool split_condition = false;
-            int bytes_remain = text_utf.size() - i;
+// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
+static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & text, const std::vector<size_t> & offsets) {
+    std::vector<size_t> bpe_offsets; // store the offset of each word
+    bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size
 
-            // forward backward lookups
-            const std::string & utf_char_next      = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : "";
-            const std::string & utf_char_next_next = (i + 2 < (int)text_utf.size()) ? text_utf[i + 2] : "";
+    const auto cpts = unicode_cpts_from_utf8(text);
 
-            // handling contractions
-            if (!split_condition && bytes_remain >= 2) {
-                // 's|'t|'m|'d
-                if (utf_char == "\'" && (utf_char_next == "s" || utf_char_next == "t" || utf_char_next == "m" || utf_char_next == "d")) {
-                    split_condition = true;
+    size_t start = 0;
+    for (auto offset : offsets) {
+        const size_t offset_ini = start;
+        const size_t offset_end = start + offset;
+        assert(offset_end <= cpts.size());
+        start = offset_end;
+
+        auto _get_cpt = [&] (const size_t pos) -> char32_t {
+            return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
+        };
+
+        auto _get_cpt_type = [&] (const size_t pos) -> int {
+            return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+        };
+
+        size_t _prev_end = offset_ini;
+        auto _add_token = [&] (const size_t end) -> size_t {
+            assert(_prev_end <= end && end <= offset_end);
+            size_t len = end - _prev_end;
+            if (len > 0) {
+                bpe_offsets.push_back(len);
+            }
+            _prev_end = end;
+            //if (len > 0) {
+            //    std::string s = "";
+            //    for(size_t p = end-len; p < end; p++)
+            //        s += unicode_cpt_to_utf8(cpts[p]);
+            //    printf(">>> '%s'\n", s.c_str());
+            //}
+            return len;
+        };
+
+        for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
+            const char32_t cpt = _get_cpt(pos);
+            const int cpt_type = _get_cpt_type(pos);
+
+            // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
+            if (cpt == '\'' && pos+1 < offset_end) {
+                char32_t cpt_next = unicode_tolower(_get_cpt(pos+1));
+                if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') {
+                    pos += _add_token(pos+2);
+                    continue;
                 }
-                if (split_condition) {
-                    if (token.size()) {
-                        bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
+                if (pos+2 < offset_end) {
+                    char32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2));
+                    if ((cpt_next == 'r' && cpt_next_next == 'e') ||
+                        (cpt_next == 'v' && cpt_next_next == 'e') ||
+                        (cpt_next == 'l' && cpt_next_next == 'l')) {
+                        pos += _add_token(pos+3);
+                        continue;
                     }
-                    token = utf_char + utf_char_next;
-                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
-                    token = "";
-                    i++;
-                    continue;
                 }
             }
-            if (!split_condition && bytes_remain >= 3) {
-                // 're|'ve|'ll
-                if (utf_char == "\'" && (
-                    (utf_char_next == "r" && utf_char_next_next == "e") ||
-                    (utf_char_next == "v" && utf_char_next_next == "e") ||
-                    (utf_char_next == "l" && utf_char_next_next == "l"))
-                    ) {
-                    split_condition = true;
-                }
-                if (split_condition) {
-                    // current token + next token can be defined
-                    if (token.size()) {
-                        bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
-                    }
-                    token =  utf_char;
-                    token += utf_char_next;
-                    token += utf_char_next_next;
 
-                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
-                    token = "";
-                    i += 2;
+            // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
+            if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
+                if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
+                    pos++;
+                    while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
+                        pos++;
+                    }
+                    _add_token(pos);
                     continue;
                 }
             }
 
-            if (!split_condition && !collecting) {
-                if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER)) {
-                    collecting_letter = true;
-                    collecting = true;
-                }
-                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_NUMBER || (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_NUMBER)) {
-                    collecting_numeric = true;
-                    collecting = true;
-                }
-                else if (
-                    ((unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_NUMBER) && (unicode_cpt_type(utf_char) != CODEPOINT_TYPE_WHITESPACE)) ||
-                    (token.empty() && utf_char == " " && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_LETTER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_NUMBER && unicode_cpt_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE)
-                    ) {
-                    collecting_special = true;
-                    collecting = true;
-                }
-                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE && unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_WHITESPACE) {
-                    collecting_whitespace_lookahead = true;
-                    collecting = true;
-                }
-                else if (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE) {
-                    split_condition = true;
+            // regex: \p{N}{1,3}
+            if (cpt_type == CODEPOINT_TYPE_NUMBER) {
+                size_t ini = pos;
+                while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
+                    if (++pos - ini >= 3 ) {
+                        _add_token(pos);
+                        ini = pos;
+                    }
                 }
+                _add_token(pos);
+                continue;
             }
-            else if (!split_condition && collecting) {
-                if (collecting_letter && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_LETTER) {
-                    split_condition = true;
-                }
-                else if (collecting_numeric && unicode_cpt_type(utf_char) != CODEPOINT_TYPE_NUMBER) {
-                    split_condition = true;
+
+            // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
+            char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
+            int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
+            if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                pos += (cpt == ' ');
+                while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+                    cpt2_type = _get_cpt_type(++pos);
+                    cpt2 = _get_cpt(pos);
                 }
-                else if (collecting_special && (unicode_cpt_type(utf_char) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_NUMBER || unicode_cpt_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) {
-                    split_condition = true;
+                while (cpt2 == '\r' || cpt2 == '\n') {
+                    cpt2 = _get_cpt(++pos);
                 }
-                else if (collecting_whitespace_lookahead && (unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_LETTER || unicode_cpt_type(utf_char_next) == CODEPOINT_TYPE_NUMBER)) {
-                    split_condition = true;
+                _add_token(pos);
+                continue;
+            }
+
+            size_t num_whitespaces = 0;
+            size_t last_end_r_or_n = 0;
+            while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+                char32_t cpt2 = _get_cpt(pos+num_whitespaces);
+                if (cpt2 == '\r' || cpt2 == '\n') {
+                    last_end_r_or_n = pos + num_whitespaces + 1;
                 }
+                num_whitespaces++;
             }
 
-            if (utf_char_next == "") {
-                split_condition = true; // final
-                token += utf_char;
+            // regex: \s*[\r\n]+
+            if (last_end_r_or_n > 0) {
+                pos = last_end_r_or_n;
+                _add_token(pos);
+                continue;
             }
 
-            if (split_condition) {
-                if (token.size()) {
-                    bpe_offsets.emplace_back(unicode_cpts_from_utf8(token).size());
-                }
-                token = utf_char;
-                collecting = false;
-                collecting_letter = false;
-                collecting_numeric = false;
-                collecting_special = false;
-                collecting_whitespace_lookahead = false;
+            // regex: \s+(?!\S)
+            if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
+                pos += num_whitespaces - 1;
+                _add_token(pos);
+                continue;
             }
-            else {
-                token += utf_char;
+
+            // regex: \s+
+            if (num_whitespaces > 0) {
+                pos += num_whitespaces;
+                _add_token(pos);
+                continue;
             }
-        }
 
-        start += offset;
+            // no matches
+            _add_token(++pos);
+        }
     }
 
     return bpe_offsets;
@@ -424,14 +543,14 @@ static std::vector<size_t> unicode_regex_split_stl(const std::string & text, con
 static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) {
     std::vector<size_t> bpe_offsets;
 
-    (void)(text);
-    (void)(regex_expr);
-    (void)(offsets);
-    // TODO: this implementation is actually wrong, uncomment and run:
-    //       make -j && ./bin/test-tokenizer-0 ../models/ggml-vocab-gpt-2.gguf
-    //if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
-    //    bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
-    //}
+    if (regex_expr == "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)") {
+        bpe_offsets = unicode_regex_split_custom_gpt2(text, offsets);
+    } else if (
+            regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" ||
+            regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") {
+
+        bpe_offsets = unicode_regex_split_custom_llama3(text, offsets);
+    }
 
     return bpe_offsets;
 }
@@ -506,6 +625,19 @@ int unicode_cpt_type(const std::string & utf8) {
     return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
 }
 
+bool unicode_cpt_is_whitespace(uint32_t cp) {
+    static const std::unordered_set<uint32_t> is_whitespace = [] {
+        std::unordered_set<uint32_t> is_whitespace;
+        for (auto p : unicode_ranges_whitespace) {
+            for (auto i = p.first; i <= p.second; ++i) {
+                is_whitespace.insert(i);
+            }
+        }
+        return is_whitespace;
+    }();
+    return (bool)is_whitespace.count(cp);
+}
+
 std::string unicode_byte_to_utf8(uint8_t byte) {
     static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map();
     return map.at(byte);
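
unicode_regex_split_custom now dispatches on the exact regex string: the GPT-2 pre-tokenizer pattern goes to the rewritten unicode_regex_split_custom_gpt2, and the llama 3 pattern (plus its case-expanded variant) goes to the new unicode_regex_split_custom_llama3. The behavior those hand-written splitters are meant to reproduce can be cross-checked with the Python regex module, which supports \p{...} and scoped (?i:...) natively; a reference sketch for comparison only, not the path llama.cpp itself takes:

# Reference splitter using the third-party 'regex' module, for cross-checking the
# hand-written C++ splitters; llama.cpp does not use this code path.
import regex

GPT2_PATTERN   = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)"
LLAMA3_PATTERN = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

text = "Hello world 123456 don't"
print(regex.findall(GPT2_PATTERN, text))    # numbers kept as one ' ?\p{N}+' chunk
print(regex.findall(LLAMA3_PATTERN, text))  # \p{N}{1,3} splits digits into groups of at most 3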

+ 3 - 1
unicode.h

@@ -7,7 +7,7 @@
 #define CODEPOINT_TYPE_UNIDENTIFIED 0
 #define CODEPOINT_TYPE_NUMBER       1
 #define CODEPOINT_TYPE_LETTER       2
-#define CODEPOINT_TYPE_WHITESPACE   3
+#define CODEPOINT_TYPE_SEPARATOR    3
 #define CODEPOINT_TYPE_ACCENT_MARK  4
 #define CODEPOINT_TYPE_PUNCTUATION  5
 #define CODEPOINT_TYPE_SYMBOL       6
@@ -21,6 +21,8 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
 int unicode_cpt_type(uint32_t cp);
 int unicode_cpt_type(const std::string & utf8);
 
+bool unicode_cpt_is_whitespace(uint32_t cp);
+
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
 

Some files were not shown because too many files changed in this diff