
# Test libllama tokenizer == AutoTokenizer.
# Brute force random words/text generation.
#
# Sample usage:
#
#   python3 tests/test-tokenizer-random.py ./models/ggml-vocab-llama-bpe.gguf ./models/tokenizers/llama-bpe
#

import time
import logging
import argparse
import subprocess
import random
import unicodedata

from typing import Callable, Iterator

import cffi
from transformers import AutoTokenizer


logger = logging.getLogger("test-tokenizer-random")


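# Minimal cffi binding for libllama: llama.h is preprocessed with gcc -E (plus a
# few workarounds for constructs pycparser cannot handle), fed to ffi.cdef(), and
# the shared library is opened with dlopen; the backend is initialized once here.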
class LibLlama:

    DEFAULT_PATH_LLAMA_H = "./llama.h"
    DEFAULT_PATH_LIBLLAMA = "./build/libllama.so"  # CMakeLists.txt: BUILD_SHARED_LIBS ON

    def __init__(self, path_llama_h: str = None, path_libllama: str = None):
        path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H
        path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA
        (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_libllama)
        self.lib.llama_backend_init()

    def _load_libllama_cffi(self, path_llama_h: str, path_libllama: str):
        cmd = ["gcc", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)=", path_llama_h]
        res = subprocess.run(cmd, stdout=subprocess.PIPE)
        assert (res.returncode == 0)
        source = res.stdout.decode()
        ffi = cffi.FFI()
        if True:  # workarounds for pycparser
            source = "typedef struct { } __builtin_va_list;" + "\n" + source
            source = source.replace("sizeof (int)", str(ffi.sizeof("int")))
            source = source.replace("sizeof (void *)", str(ffi.sizeof("void*")))
            source = source.replace("sizeof (size_t)", str(ffi.sizeof("size_t")))
            source = source.replace("sizeof(int32_t)", str(ffi.sizeof("int32_t")))
        ffi.cdef(source, override=True)
        lib = ffi.dlopen(path_libllama)
        return (ffi, lib)

    def model_default_params(self, **kwargs):
        mparams = self.lib.llama_model_default_params()
        for k, v in kwargs.items():
            setattr(mparams, k, v)
        return mparams

    def context_default_params(self, **kwargs):
        cparams = self.lib.llama_context_default_params()
        for k, v in kwargs.items():
            setattr(cparams, k, v)
        return cparams


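# Thin wrapper around a loaded model/context pair. In this test the model is
# loaded with vocab_only=True and only llama_tokenize() is exercised.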
class LibLlamaModel:

    def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}):
        self.lib = libllama.lib
        self.ffi = libllama.ffi
        if isinstance(mparams, dict):
            mparams = libllama.model_default_params(**mparams)
        self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams)
        if not self.model:
            raise RuntimeError("error: failed to load model '%s'" % path_model)
        if isinstance(cparams, dict):
            cparams = libllama.context_default_params(**cparams)
        self.ctx = self.lib.llama_new_context_with_model(self.model, cparams)
        if not self.ctx:
            raise RuntimeError("error: failed to create context for model '%s'" % path_model)
        n_tokens_max = self.lib.llama_n_ctx(self.ctx)
        self.token_ids = self.ffi.new("llama_token[]", n_tokens_max)

    def free(self):
        if self.ctx:
            self.lib.llama_free(self.ctx)
        if self.model:
            self.lib.llama_free_model(self.model)
        self.ctx = None
        self.model = None
        self.lib = None

    def tokenize(self, text: str, n_tokens_max: int = 0, add_special: bool = False, parse_special: bool = False) -> list[int]:
        n_tokens_max = n_tokens_max if n_tokens_max > 0 else len(self.token_ids)
        text = text.encode("utf-8")
        num = self.lib.llama_tokenize(self.model, text, len(text), self.token_ids, n_tokens_max, add_special, parse_special)
        if num < 0:
            return []
        return list(self.token_ids[0:num])


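# The generator_* functions below produce the test strings that compare_tokenizers()
# feeds to both tokenizers: fixed texts, known edge cases, the whole vocabulary,
# and several brute-force random mixtures.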
def generator_custom_text() -> Iterator[str]:
    """General tests"""
    yield from [
        "",
        " ",
        "  ",
        "   ",
        "\t",
        "\n",
        "\n\n",
        "\n\n\n",
        "\t\n",
        "Hello world",
        " Hello world",
        "Hello World",
        " Hello World",
        " Hello World!",
        "Hello, world!",
        " Hello, world!",
        " this is 🦙.cpp",
        "w048 7tuijk dsdfhu",
        "нещо на Български",
        "កាន់តែពិសេសអាចខលចេញ",
        "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)",
        "Hello",
        " Hello",
        "  Hello",
        "   Hello",
        "    Hello",
        "    Hello\n    Hello",
        " (",
        "\n =",
        "' era",
        "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
        "3",
        "33",
        "333",
        "3333",
        "33333",
        "333333",
        "3333333",
        "33333333",
        "333333333",
    ]


def generator_custom_text_edge_cases() -> Iterator[str]:
    """Edge cases found while debugging"""
    yield from [
        '\x1f-a',                  # unicode_ranges_control, {0x00001C, 0x00001F}
        '¼-a',                     # unicode_ranges_digit, 0x00BC
        '½-a',                     # unicode_ranges_digit, 0x00BD
        '¾-a',                     # unicode_ranges_digit, 0x00BE
        'a 〇b',                   # unicode_ranges_digit, 0x3007
        'Ⅵ-a',                    # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
        '\uFEFF//',                # unicode_ranges_control, 0xFEFF (BOM)
        'Cửa Việt',                # llama-3, ignore_merges = true
        '<s>a',                    # Phi-3 fail
        '<unk><|endoftext|><s>',   # Phi-3 fail
        'a\na',                    # bert fail
        '"`',                      # falcon
        ' \u2e4e',                 # falcon
        'a\xa0\xa0\x00b',          # jina-v2-es
        'one <mask>',              # jina-v2-es <mask> lstrip=true
        'a </s> b',                # rstrip phi-3
        'a <mask> b',              # lstrip jina-v2
        '\xa0aC',                  # deepseek
    ]


def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
    """Brute force check all vocab words"""
    yield from vocab


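# Wrap every added/special token in all combinations of surrounding whitespace
# (and adjacent plain characters) to exercise lstrip/rstrip handling.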
def generator_added_lr_strip(tokenizer) -> Iterator[str]:
    WHITESPACES = ["", " ", "  ", "   "]
    special_tokens = list(tokenizer.all_special_tokens)
    added_tokens = list(tokenizer.added_tokens_encoder)
    all_tokens = list(sorted(set(special_tokens + added_tokens)))
    for token in all_tokens:
        for lstrip in WHITESPACES:
            for rstrip in WHITESPACES:
                yield lstrip + token + rstrip
                yield "a" + lstrip + token + rstrip
                yield lstrip + token + rstrip + "z"
                yield "a" + lstrip + token + rstrip + "z"


def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
    special_tokens = list(tokenizer.all_special_tokens)
    added_tokens = list(tokenizer.added_tokens_encoder)
    separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
    all_tokens = list(sorted(set(special_tokens + added_tokens + separations)))
    rand = random.Random()
    for m in range(iterations):
        rand.seed(m)
        words = rand.choices(all_tokens, k=500)
        if words and words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
            while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                words.pop(0)
            if tokenizer.add_bos_token:  # drop all starting BOS
                words.pop(0)
        if words and words[-1] == tokenizer.eos_token:  # skip spam warning of double EOS
            while len(words) > 1 and words[-2] == tokenizer.eos_token:  # leave one trailing EOS
                words.pop(-1)
            if tokenizer.add_eos_token:  # drop all trailing EOS
                words.pop(-1)
        yield "".join(words)


def generator_random_chars(iterations=100) -> Iterator[str]:
    """Brute force random text with simple characters"""
    NUM_WORDS = 400
    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
    CHARS = list(sorted(set("""
        ABCDEFGHIJKLMNOPQRSTUVWXYZ
        abcdefghijklmnopqrstuvwxyz
        ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
        áéíóúàèìòùâêîôûäëïöü
        .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
    """)))
    rand = random.Random()
    for m in range(iterations):
        rand.seed(m)
        text = []
        for _ in range(NUM_WORDS):
            k = rand.randint(1, 7)
            word = rand.choices(CHARS, k=k)
            word.append(rand.choice(WHITESPACES))
            text.append("".join(word))
        yield "".join(text)


def generator_unicodes() -> Iterator[str]:
    """Iterate unicode characters"""
    MAX_CODEPOINTS = 0x30000  # 0x110000

    def _valid(cpt):
        if cpt >= 0x30000:  # unassigned and supplementary
            return False
        if 0x00D800 <= cpt <= 0x00F8FF:  # Surrogates
            return False
        if unicodedata.category(chr(cpt)) == "Cn":
            return False
        return True

    characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
    yield from characters


def generator_random_unicodes(iterations=100) -> Iterator[str]:
    """Brute force random text with unicode characters"""
    NUM_WORDS = 200
    WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
    characters = list(generator_unicodes())
    rand = random.Random()
    for m in range(iterations):
        rand.seed(m)
        text = []
        for _ in range(NUM_WORDS):
            k = rand.randint(1, 7)
            word = rand.choices(characters, k=k)
            word.append(rand.choice(WHITESPACES))
            text.append("".join(word))
        yield "".join(text)


def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
    """Brute force random text with vocab characters"""
    vocab_chars = set()
    for word in vocab:
        vocab_chars.update(word)
    vocab_chars = list(sorted(vocab_chars))
    rand = random.Random()
    for m in range(iterations):
        rand.seed(m)
        text = rand.choices(vocab_chars, k=1024)
        yield "".join(text)


def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
    """Brute force random text from vocab words"""
    vocab = [w.strip() for w in vocab]
    yield from vocab
    rand = random.Random()
    for m in range(iterations):
        rand.seed(m)
        text = []
        num_words = rand.randint(300, 400)
        for i in range(num_words):
            k = rand.randint(1, 3)
            words = rand.choices(vocab, k=k)
            sep = rand.choice(" \n\r\t")
            text.append("".join(words) + sep)
        yield "".join(text)


def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):

    def find_first_mismatch(ids1: list[int], ids2: list[int]):
        for i, (a, b) in enumerate(zip(ids1, ids2)):
            if a != b:
                return i
        if len(ids1) == len(ids2):
            return -1
        return min(len(ids1), len(ids2))

    t_tokenizer1 = 0
    t_tokenizer2 = 0
    t_start = time.perf_counter()
    num_errors = 10

    logger.info("%s: %s" % (generator.__name__, "ini"))
    for text in generator:
        # print(repr(text), hex(ord(text[0])), text.encode())
        t0 = time.perf_counter()
        ids1 = func_tokenize1(text)
        t1 = time.perf_counter()
        ids2 = func_tokenize2(text)
        t2 = time.perf_counter()
        t_tokenizer1 += t1 - t0
        t_tokenizer2 += t2 - t1
        if ids1 != ids2:
            i = find_first_mismatch(ids1, ids2)
            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
            logger.error(" TokenIDs: " + str(ids1))
            logger.error(" Expected: " + str(ids2))
            # raise Exception()
            num_errors += 1
            if num_errors > 10:
                break

    t_total = time.perf_counter() - t_start
    logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))


def main(argv: list[str] = None):
    parser = argparse.ArgumentParser()
    parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
    parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
    parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
    args = parser.parse_args(argv)

    logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
    logger.info(f"VOCABFILE: '{args.vocab_file}'")

    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
    tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)

    def func_tokenize1(text: str):
        return model.tokenize(text, add_special=True, parse_special=True)

    def func_tokenize2(text: str):
        return tokenizer.encode(text, add_special_tokens=True)
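
    # Infer whether the HF tokenizer prepends BOS / appends EOS by encoding a
    # single character and inspecting the first/last ids; the detected values
    # are used only if the tokenizer does not already expose these attributes.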
    ids = func_tokenize2("a")
    assert 1 <= len(ids) <= 3
    add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
    add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
    tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
    tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)

    vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))

    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
    compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))

    model.free()


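# When run directly, the comparison is repeated for a hardcoded list of
# tokenizers, with verbose logging appended to 'test-tokenizer-random.log'.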
if __name__ == "__main__":
    # main()

    logging.basicConfig(
        level    = logging.DEBUG,
        format   = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
        datefmt  = "%Y-%m-%d %H:%M:%S",
        filename = logger.name + ".log",
        filemode = "a"
    )

    path_tokenizers   = "./models/tokenizers/"
    path_vocab_format = "./models/ggml-vocab-%s.gguf"

    # import os
    # tokenizers = os.listdir(path_tokenizers)
    tokenizers = [
        # "llama-spm",      # SPM
        # "phi-3",          # SPM
        # "bert-bge",       # WPM
        # "jina-v2-en",     # WPM
        "gpt-2",            # BPE
        "llama-bpe",        # BPE
        "falcon",           # BPE
        "starcoder",        # BPE
        "jina-v2-es",       # BPE
        "jina-v2-de",       # BPE
        "jina-v2-code",     # BPE
        "smaug-bpe",        # BPE
        "phi-2",            # BPE
        "deepseek-coder",   # BPE
        "deepseek-llm",     # BPE
    ]

    for tokenizer in tokenizers:
        logger.info("=" * 50)
        logger.info(f"TOKENIZER: '{tokenizer}'")
        vocab_file = path_vocab_format % tokenizer
        dir_tokenizer = path_tokenizers + "/" + tokenizer
        main([vocab_file, dir_tokenizer, "--verbose"])