// test-tokenizer-1-llama.cpp
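//
// Round-trip test for the SPM (SentencePiece) tokenizer: every vocab token
// and every valid Unicode codepoint must survive a tokenize -> detokenize
// round trip unchanged. Expects one argument: the path to the vocab file.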

#include "llama.h"
#include "common.h"
#include "console.h"

#include <cassert>
#include <cstdio>
#include <cstring>
#include <string>
#include <codecvt>
#include <map>
#include <vector>
#include <locale>
#include <stdexcept>

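// minimal UTF-8 encoder used to generate a test string for every codepoint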
typedef int codepoint;

static std::string codepoint_to_utf8(codepoint cp) {
    std::string result;
    if (0x00 <= cp && cp <= 0x7f) {
        // 1-byte sequence: ASCII
        result.push_back(cp);
    } else if (0x80 <= cp && cp <= 0x7ff) {
        // 2-byte sequence
        result.push_back(0xc0 | ((cp >> 6) & 0x1f));
        result.push_back(0x80 | (cp & 0x3f));
    } else if (0x800 <= cp && cp <= 0xffff) {
        // 3-byte sequence
        result.push_back(0xe0 | ((cp >> 12) & 0x0f));
        result.push_back(0x80 | ((cp >> 6) & 0x3f));
        result.push_back(0x80 | (cp & 0x3f));
    } else if (0x10000 <= cp && cp <= 0x10ffff) {
        // 4-byte sequence
        result.push_back(0xf0 | ((cp >> 18) & 0x07));
        result.push_back(0x80 | ((cp >> 12) & 0x3f));
        result.push_back(0x80 | ((cp >> 6) & 0x3f));
        result.push_back(0x80 | (cp & 0x3f));
    } else {
        throw std::invalid_argument("invalid codepoint");
    }
    return result;
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <vocab-file>\n", argv[0]);
        return 1;
    }

    const std::string fname = argv[1];

    fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());

    llama_model * model;
    llama_context * ctx;

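    // initialize the backend; the boolean argument toggles NUMA support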
    llama_backend_init(false);

    // load the vocab
    {
        auto lparams = llama_context_default_params();

        lparams.vocab_only = true;

        model = llama_load_model_from_file(fname.c_str(), lparams);
        if (model == NULL) {
            fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
            return 1;
        }

        ctx = llama_new_context_with_model(model, lparams);
        if (ctx == NULL) {
            fprintf(stderr, "%s: error: failed to create context for vocab '%s'\n", __func__, fname.c_str());
            llama_free_model(model);
            return 1;
        }
    }

    GGML_ASSERT(llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM);

#ifdef _WIN32
    // We need this for unicode console support
    console::init(false, false);
    atexit([]() { console::cleanup(); });
#endif

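    // every token in the vocab must survive the round trip:
    // detokenize(token) -> tokenize -> detokenize yields the same string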
    const int n_vocab = llama_n_vocab(ctx);

    for (int i = 0; i < n_vocab; ++i) {
        std::string str = llama_detokenize_spm(ctx, std::vector<int>(1, i));
        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
        std::string check = llama_detokenize_spm(ctx, tokens);
        if (check != str) {
            fprintf(stderr, "%s : error: token %d detokenizes to '%s'(%zu) but tokenization of this detokenizes to '%s'(%zu)\n",
                __func__, i, str.c_str(), str.length(), check.c_str(), check.length());
            return 2;
        }
    }

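    // round-trip every codepoint in the Basic Multilingual Plane, skipping the
    // UTF-16 surrogate range U+D800..U+DFFF (not valid Unicode scalar values)
    // and U+2581 (decimal 9601), the '▁' symbol SPM uses as its whitespace marker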
    for (codepoint cp = 0x0000; cp < 0xffff; ++cp) {
        if (cp < 0xd800 || cp > 0xdfff) {
            std::string str = codepoint_to_utf8(cp);
            std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
            std::string check = llama_detokenize_spm(ctx, tokens);
            if (cp != 9601 && str != check) {
                fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                    __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
                return 3;
            }
        }
    }

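    // same round trip for the supplementary planes (U+10000 and above)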
    for (codepoint cp = 0x10000; cp < 0x0010ffff; ++cp) {
        std::string str = codepoint_to_utf8(cp);
        std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
        std::string check = llama_detokenize_spm(ctx, tokens);
        if (str != check) {
            fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
                __func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
            return 4;
        }
    }

    // free the context before the model it was created from
    llama_free(ctx);
    llama_free_model(model);

    llama_backend_free();

    return 0;
}