tokenize.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. #include "common.h"
  2. //#include "log.h" // TODO: start using log.h
  3. #include "llama.h"
  4. #include <cstdio>
  5. #include <cstring>
  6. #include <fstream>
  7. #include <string>
  8. #include <vector>
  9. #include <iostream> // TODO: remove me
  10. #if defined(_WIN32)
  11. #define WIN32_LEAN_AND_MEAN
  12. #include <windows.h>
  13. #include <shellapi.h> // For CommandLineToArgvW
  14. #endif
  15. static void print_usage_information(const char * argv0) {
  16. printf("usage: %s [options]\n\n", argv0);
  17. printf("The tokenize program tokenizes a prompt using a given model,\n");
  18. printf("and prints the resulting tokens to standard output.\n\n");
  19. printf("It needs a model file, a prompt, and optionally other flags\n");
  20. printf("to control the behavior of the tokenizer.\n\n");
  21. printf(" The possible options are:\n");
  22. printf("\n");
  23. printf(" -h, --help print this help and exit\n");
  24. printf(" -m MODEL_PATH, --model MODEL_PATH path to model.\n");
  25. printf(" --ids if given, only print numerical token IDs, and not token strings.\n");
  26. printf(" The output format looks like [1, 2, 3], i.e. parseable by Python.\n");
  27. printf(" -f PROMPT_FNAME, --file PROMPT_FNAME read prompt from a file.\n");
  28. printf(" -p PROMPT, --prompt PROMPT read prompt from the argument.\n");
  29. printf(" --stdin read prompt from standard input.\n");
  30. printf(" --no-bos do not ever add a BOS token to the prompt, even if normally the model uses a BOS token.\n");
  31. printf(" --no-escape do not escape input (such as \\n, \\t, etc.).\n");
  32. printf(" --no-parse-special do not parse control tokens.\n");
  33. printf(" --log-disable disable logs. Makes stderr quiet when loading the model.\n");
  34. printf(" --show-count print the total number of tokens.\n");
  35. }
  36. static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) {
  37. (void) level;
  38. (void) text;
  39. (void) user_data;
  40. }
  41. static std::string read_prompt_from_file(const char * filepath, bool & success) {
  42. success = false;
  43. std::ifstream in(filepath, std::ios::binary);
  44. if (!in) {
  45. fprintf(stderr, "%s: could not open file '%s' for reading: %s\n", __func__, filepath, strerror(errno));
  46. return std::string();
  47. }
  48. // do not assume the file is seekable (e.g. /dev/stdin)
  49. std::stringstream buffer;
  50. buffer << in.rdbuf();
  51. if (in.fail()) {
  52. fprintf(stderr, "%s: could not read the entire file '%s': %s\n", __func__, filepath, strerror(errno));
  53. return std::string();
  54. }
  55. success = true;
  56. return buffer.str();
  57. }
  58. //
  59. // Function: ingest_args(...) -> vector<string>
  60. //
  61. // Takes argc and argv arguments, and converts them to a vector of UTF-8 encoded
  62. // strings, as an STL vector<string>.
  63. //
  64. // In particular, it handles character encoding shenanigans on Windows.
  65. //
  66. // Note: raw_argc and raw_argv are not actually read at all on Windows.
  67. // On Windows we call GetCommandLineW to get the arguments in wchar_t
  68. // format, ignoring the regular argc/argv arguments to main().
  69. //
  70. // TODO: potential opportunity to roll common stuff into common/console.cpp
  71. // in relation to Windows wchar_t shenanigans.
  72. static std::vector<std::string> ingest_args(int raw_argc, char ** raw_argv) {
  73. std::vector<std::string> argv;
  74. // Handle Windows, if given non-ASCII arguments.
  75. // We convert wchar_t arguments into UTF-8 char* on this platform.
  76. // Lets you invoke 'tokenize' on Windows cmd.exe with non-ASCII characters
  77. // without throwing tantrums.
  78. #if defined(_WIN32)
  79. int argc;
  80. const LPWSTR cmdline_wargv = GetCommandLineW();
  81. LPWSTR * wargv = CommandLineToArgvW(cmdline_wargv, &argc);
  82. // silence unused arg warnings
  83. (void) raw_argc;
  84. (void) raw_argv;
  85. for (int i = 0; i < argc; ++i) {
  86. int length_needed = WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), 0, 0, NULL, NULL);
  87. char * output_buf = (char *) calloc(length_needed+1, sizeof(char));
  88. GGML_ASSERT(output_buf);
  89. WideCharToMultiByte(CP_UTF8, 0, wargv[i], wcslen(wargv[i]), output_buf, length_needed, NULL, NULL);
  90. output_buf[length_needed] = '\0';
  91. argv.push_back(output_buf);
  92. free(output_buf);
  93. }
  94. LocalFree((HLOCAL) wargv);
  95. #else
  96. int argc = raw_argc;
  97. for (int i = 0; i < argc; ++i) {
  98. argv.push_back(raw_argv[i]);
  99. }
  100. #endif
  101. GGML_ASSERT((unsigned int) argc == argv.size());
  102. return argv;
  103. }
  104. //
  105. // Function: write_utf8_cstr_to_stdout(const char *) -> <writes to stdout>
  106. //
  107. // writes a string to standard output; taking into account that on Windows
  108. // to display correctly you have to use special handling. Works even if the
  109. // user has not set a unicode code page on a Windows cmd.exe.
  110. //
  111. // In case of invalid UTF-8, invalid_utf8 is set to true on Windows, and something
  112. // a human-readable is written instead.
  113. //
  114. // On non-Windows systems, simply printfs() the string.
  115. static void write_utf8_cstr_to_stdout(const char * str, bool & invalid_utf8) {
  116. invalid_utf8 = false;
  117. #if defined(_WIN32)
  118. // Are we in a console?
  119. HANDLE hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
  120. DWORD dwMode = 0;
  121. // According to Microsoft docs:
  122. // "WriteConsole fails if it is used with a standard handle that is redirected to a file."
  123. // Also according to the docs, you can use GetConsoleMode to check for that.
  124. if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
  125. printf("%s", str);
  126. return;
  127. }
  128. // MultiByteToWideChar reports an error if str is empty, don't report
  129. // them as invalid_utf8.
  130. if (*str == 0) {
  131. return;
  132. }
  133. int length_needed = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, strlen(str), NULL, 0);
  134. if (length_needed == 0) {
  135. DWORD err = GetLastError();
  136. if (err == ERROR_NO_UNICODE_TRANSLATION) {
  137. invalid_utf8 = true;
  138. int len = strlen(str);
  139. printf("<");
  140. for (int i = 0; i < len; ++i) {
  141. if (i > 0) {
  142. printf(" ");
  143. }
  144. printf("%02x", (uint8_t) str[i]);
  145. }
  146. printf(">");
  147. return;
  148. }
  149. GGML_ABORT("MultiByteToWideChar() failed in an unexpected way.");
  150. }
  151. LPWSTR wstr = (LPWSTR) calloc(length_needed+1, sizeof(*wstr));
  152. GGML_ASSERT(wstr);
  153. MultiByteToWideChar(CP_UTF8, 0, str, strlen(str), wstr, length_needed);
  154. WriteConsoleW(hConsole, wstr, length_needed, NULL, NULL);
  155. free(wstr);
  156. #else
  157. // TODO: reporting invalid_utf8 would be useful on non-Windows too.
  158. // printf will silently just write bad unicode.
  159. printf("%s", str);
  160. #endif
  161. }
  162. int main(int raw_argc, char ** raw_argv) {
  163. const std::vector<std::string> argv = ingest_args(raw_argc, raw_argv);
  164. const int argc = argv.size();
  165. if (argc <= 1) {
  166. print_usage_information(argv[0].c_str());
  167. return 1;
  168. }
  169. //////
  170. // Read out all the command line arguments.
  171. //////
  172. // variables where to put any arguments we see.
  173. bool printing_ids = false;
  174. bool no_bos = false;
  175. bool no_escape = false;
  176. bool no_parse_special = false;
  177. bool disable_logging = false;
  178. bool show_token_count = false;
  179. const char * model_path = NULL;
  180. const char * prompt_path = NULL;
  181. const char * prompt_arg = NULL;
  182. // track which arguments were explicitly given
  183. // used for sanity checking down the line
  184. bool model_path_set = false;
  185. bool prompt_path_set = false;
  186. bool prompt_set = false;
  187. bool stdin_set = false;
  188. int iarg = 1;
  189. for (; iarg < argc; ++iarg) {
  190. std::string arg{argv[iarg]};
  191. if (arg == "-h" || arg == "--help") {
  192. print_usage_information(argv[0].c_str());
  193. return 0;
  194. }
  195. else if (arg == "--ids") {
  196. printing_ids = true;
  197. }
  198. else if (arg == "-m" || arg == "--model") {
  199. if (model_path_set) {
  200. fprintf(stderr, "Error: -m or --model specified multiple times.\n");
  201. return 1;
  202. }
  203. model_path = argv[++iarg].c_str();
  204. model_path_set = true;
  205. }
  206. else if (arg == "--no-bos") {
  207. no_bos = true;
  208. }
  209. else if (arg == "--no-escape") {
  210. no_escape = true;
  211. }
  212. else if (arg == "--no-parse-special") {
  213. no_parse_special = true;
  214. }
  215. else if (arg == "-p" || arg == "--prompt") {
  216. if (prompt_set) {
  217. fprintf(stderr, "Error: -p or --prompt specified multiple times.\n");
  218. return 1;
  219. }
  220. prompt_arg = argv[++iarg].c_str();
  221. prompt_set = true;
  222. }
  223. else if (arg == "-f" || arg == "--file") {
  224. if (prompt_path_set) {
  225. fprintf(stderr, "Error: -f or --file specified multiple times.\n");
  226. return 1;
  227. }
  228. prompt_path = argv[++iarg].c_str();
  229. prompt_path_set = true;
  230. }
  231. else if (arg == "--stdin") {
  232. stdin_set = true;
  233. }
  234. else if (arg == "--log-disable") {
  235. disable_logging = true;
  236. }
  237. else if (arg == "--show-count") {
  238. show_token_count = true;
  239. }
  240. else {
  241. fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str());
  242. return 1;
  243. }
  244. }
  245. //////
  246. // Sanity check the command line arguments.
  247. //////
  248. // Check that we have the required stuff set.
  249. if (model_path_set && model_path == NULL) {
  250. fprintf(stderr, "Error: --model requires an argument.\n");
  251. return 1;
  252. }
  253. if (!model_path_set) {
  254. fprintf(stderr, "Error: must specify --model.\n");
  255. return 1;
  256. }
  257. if (prompt_path_set && prompt_path == NULL) {
  258. fprintf(stderr, "Error: --file requires an argument.\n");
  259. return 1;
  260. }
  261. if (prompt_set && prompt_arg == NULL) {
  262. fprintf(stderr, "Error: --prompt requires an argument.\n");
  263. return 1;
  264. }
  265. const int prompts_set = !!(prompt_path_set) + !!(prompt_set) + !!(stdin_set);
  266. if (prompts_set > 1) {
  267. fprintf(stderr, "Error: --stdin, --file and --prompt are mutually exclusive.\n");
  268. return 1;
  269. }
  270. // Must have some prompt.
  271. if (prompts_set == 0) {
  272. fprintf(stderr, "Error: must specify one of: --stdin, --file or --prompt.\n");
  273. return 1;
  274. }
  275. GGML_ASSERT(model_path);
  276. GGML_ASSERT(prompt_path || prompt_arg || stdin_set);
  277. //////
  278. // Figure out where will the prompt come from.
  279. //////
  280. std::string prompt;
  281. if (prompt_path_set) {
  282. bool success = false;
  283. prompt = read_prompt_from_file(prompt_path, success);
  284. if (!success) {
  285. return 1;
  286. }
  287. } else if (prompt_set) {
  288. prompt = prompt_arg;
  289. } else {
  290. GGML_ASSERT(stdin_set);
  291. // we read stdin *after* loading model (early exit if model cannot
  292. // be loaded, which can be a nicer user experience)
  293. }
  294. //////
  295. // Start actually doing the tokenizing stuff.
  296. //////
  297. if (disable_logging) {
  298. llama_log_set(llama_log_callback_null, NULL);
  299. }
  300. llama_backend_init();
  301. llama_model_params model_params = llama_model_default_params();
  302. model_params.vocab_only = true;
  303. llama_model * model = llama_model_load_from_file(model_path, model_params);
  304. if (!model) {
  305. fprintf(stderr, "Error: could not load model from file '%s'.\n", model_path);
  306. return 1;
  307. }
  308. const llama_vocab * vocab = llama_model_get_vocab(model);
  309. llama_context_params ctx_params = llama_context_default_params();
  310. llama_context * ctx = llama_init_from_model(model, ctx_params);
  311. if (!ctx) {
  312. fprintf(stderr, "Error: could not create context.\n");
  313. return 1;
  314. }
  315. // read entire prompt from stdin?
  316. if (stdin_set) {
  317. GGML_ASSERT(!prompt_path_set && !prompt_set);
  318. std::stringstream stdin_buffer;
  319. stdin_buffer << std::cin.rdbuf();
  320. if (std::cin.fail()) {
  321. fprintf(stderr, "Error: could not read the entire standard input.\n");
  322. return 1;
  323. }
  324. prompt = stdin_buffer.str();
  325. }
  326. const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab);
  327. const bool add_bos = model_wants_add_bos && !no_bos;
  328. const bool parse_special = !no_parse_special;
  329. const bool escape = !no_escape;
  330. if (escape) {
  331. string_process_escapes(prompt);
  332. }
  333. std::vector<llama_token> tokens;
  334. tokens = common_tokenize(vocab, prompt, add_bos, parse_special);
  335. if (printing_ids) {
  336. printf("[");
  337. }
  338. for (int i = 0; i < (int) tokens.size(); i++) {
  339. if (printing_ids) {
  340. if (i > 0) {
  341. printf(", ");
  342. }
  343. printf("%d", tokens[i]);
  344. } else {
  345. bool invalid_utf8 = false;
  346. printf("%6d -> '", tokens[i]);
  347. write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
  348. if (invalid_utf8) {
  349. printf("' (utf-8 decode failure)\n");
  350. } else {
  351. printf("'\n");
  352. }
  353. }
  354. }
  355. if (printing_ids) {
  356. printf("]\n");
  357. }
  358. if (show_token_count) {
  359. printf("Total number of tokens: %zu\n", tokens.size());
  360. }
  361. // silence valgrind
  362. llama_free(ctx);
  363. llama_model_free(model);
  364. return 0;
  365. }