cli.cpp

#include "common.h"
#include "arg.h"
#include "console.h"
// #include "log.h"
#include "server-context.h"
#include "server-task.h"

#include <atomic>
#include <fstream>
#include <thread>

#include <signal.h>

#if defined(_WIN32)
#    define WIN32_LEAN_AND_MEAN
#    ifndef NOMINMAX
#        define NOMINMAX
#    endif
#    include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
██ ██
▀▀ ▀▀
)";
static std::atomic<bool> g_is_interrupted = false;

static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
        // make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
        fprintf(stdout, "\033[0m\n");
        fflush(stdout);
        std::exit(130);
    }
    g_is_interrupted.store(true);
}
#endif
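
// state for an interactive session: the underlying server_context, the chat
// history, any attached input files and the default per-task parameters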
struct cli_context {
    server_context ctx_server;

    json                    messages = json::array();
    std::vector<raw_buffer> input_files;
    task_params             defaults;

    // thread for showing "loading" animation
    std::atomic<bool> loading_show;

    cli_context(const common_params & params) {
        defaults.sampling    = params.sampling;
        defaults.speculative = params.speculative;
        defaults.n_keep      = params.n_keep;
        defaults.n_predict   = params.n_predict;
        defaults.antiprompt  = params.antiprompt;

        defaults.stream            = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
    }
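
    // run a single completion over the current chat history: apply the chat
    // template, post one task to the server and stream the result to the console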
    std::string generate_completion(result_timings & out_timings) {
        server_response_reader rd = ctx_server.get_response_reader();
        auto chat_params = format_chat();
        {
            // TODO: reduce some copies here in the future
            server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
            task.id         = rd.get_new_id();
            task.index      = 0;
            task.params     = defaults;           // copy
            task.cli_prompt = chat_params.prompt; // copy
            task.cli_files  = input_files;        // copy
            task.cli        = true;

            // chat template settings
            task.params.chat_parser_params = common_chat_parser_params(chat_params);
            task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
            if (!chat_params.parser.empty()) {
                task.params.chat_parser_params.parser.load(chat_params.parser);
            }
            rd.post_task({std::move(task)});
        }

        // wait for first result
        console::spinner::start();
        server_task_result_ptr result = rd.next(should_stop);
        console::spinner::stop();

        std::string curr_content;
        bool is_thinking = false;

        while (result) {
            if (should_stop()) {
                break;
            }
            if (result->is_error()) {
                json err_data = result->to_json();
                if (err_data.contains("message")) {
                    console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                return curr_content;
            }

            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
            if (res_partial) {
                out_timings = std::move(res_partial->timings);
                for (const auto & diff : res_partial->oaicompat_msg_diffs) {
                    if (!diff.content_delta.empty()) {
                        if (is_thinking) {
                            console::log("\n[End thinking]\n\n");
                            console::set_display(DISPLAY_TYPE_RESET);
                            is_thinking = false;
                        }
                        curr_content += diff.content_delta;
                        console::log("%s", diff.content_delta.c_str());
                        console::flush();
                    }
                    if (!diff.reasoning_content_delta.empty()) {
                        console::set_display(DISPLAY_TYPE_REASONING);
                        if (!is_thinking) {
                            console::log("[Start thinking]\n");
                        }
                        is_thinking = true;
                        console::log("%s", diff.reasoning_content_delta.c_str());
                        console::flush();
                    }
                }
            }

            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                break;
            }

            result = rd.next(should_stop);
        }
        g_is_interrupted.store(false);

        // server_response_reader automatically cancels pending tasks upon destruction
        return curr_content;
    }
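
    // read a local file and attach it to the next message: text files are
    // returned inline; media files are buffered in input_files and replaced
    // by a multimodal marker in the prompt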
    // TODO: support remote files in the future (http, https, etc)
    std::string load_input_file(const std::string & fname, bool is_media) {
        std::ifstream file(fname, std::ios::binary);
        if (!file) {
            return "";
        }
        if (is_media) {
            raw_buffer buf;
            buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            input_files.push_back(std::move(buf));
            return mtmd_default_marker();
        } else {
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            return content;
        }
    }
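
    // build the prompt by applying the model's chat template to the
    // accumulated OAI-compatible message list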
    common_chat_params format_chat() {
        auto meta = ctx_server.get_meta();
        auto & chat_params = meta.chat_params;

        common_chat_templates_inputs inputs;
        inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
        inputs.tools                 = {}; // TODO
        inputs.tool_choice           = COMMON_CHAT_TOOL_CHOICE_NONE;
        inputs.json_schema           = ""; // TODO
        inputs.grammar               = ""; // TODO
        inputs.use_jinja             = chat_params.use_jinja;
        inputs.parallel_tool_calls   = false;
        inputs.add_generation_prompt = true;
        inputs.enable_thinking       = chat_params.enable_thinking;

        // Apply chat template to the list of messages
        return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
    }
};

int main(int argc, char ** argv) {
    common_params params;
    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
    }

    // TODO: maybe support it later?
    if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
        console::error("--no-conversation is not supported by llama-cli\n");
        console::error("please use llama-completion instead\n");
    }

    common_init();

    // struct that contains llama context and inference
    cli_context ctx_cli(params);

    llama_backend_init();
    llama_numa_init(params.numa);

    // TODO: avoid using atexit() here by making `console` a singleton
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });
    console::set_display(DISPLAY_TYPE_RESET);

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
        console::error("\nFailed to load the model\n");
        return 1;
    }
    console::spinner::stop();
    console::log("\n");
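
    // the server loop runs on its own thread; everything below acts as a
    // client that posts tasks to it and reads back streamed results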
    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });

    auto inf = ctx_cli.ctx_server.get_meta();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
    }
    if (inf.has_inp_audio) {
        modalities += ", audio";
    }

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role",    "system"},
            {"content", params.system_prompt}
        });
    }

    console::log("\n");
    console::log("%s\n", LLAMA_ASCII_LOGO);
    console::log("build      : %s\n", inf.build_info.c_str());
    console::log("model      : %s\n", inf.model_name.c_str());
    console::log("modalities : %s\n", modalities.c_str());
    if (!params.system_prompt.empty()) {
        console::log("using custom system prompt\n");
    }
    console::log("\n");
    console::log("available commands:\n");
    console::log("  /exit or Ctrl+C    stop or exit\n");
    console::log("  /regen             regenerate the last response\n");
    console::log("  /clear             clear the chat history\n");
  240. console::log(" /read add a text file\n");
    if (inf.has_inp_image) {
        console::log("  /image <file>      add an image file\n");
    }
    if (inf.has_inp_audio) {
        console::log("  /audio <file>      add an audio file\n");
    }
    console::log("\n");

    // interactive loop
    std::string cur_msg;
    while (true) {
        std::string buffer;
        console::set_display(DISPLAY_TYPE_USER_INPUT);

        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
        } else {
            // process input prompt from args
            for (auto & fname : params.image) {
                std::string marker = ctx_cli.load_input_file(fname, true);
                if (marker.empty()) {
                    console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                    break;
                }
                console::log("Loaded media from '%s'\n", fname.c_str());
                cur_msg += marker;
            }
            buffer = params.prompt;
            if (buffer.size() > 500) {
                console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
        }

        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");

        if (should_stop()) {
            g_is_interrupted.store(false);
            break;
        }

        // remove trailing newline
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

        // skip empty messages
        if (buffer.empty()) {
            continue;
        }
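
        // add_user_msg is cleared by /regen, which resubmits the existing
        // history instead of appending a new user message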
        bool add_user_msg = true;

        // process commands
        if (string_starts_with(buffer, "/exit")) {
            break;
        } else if (string_starts_with(buffer, "/regen")) {
            if (ctx_cli.messages.size() >= 2) {
                size_t last_idx = ctx_cli.messages.size() - 1;
                ctx_cli.messages.erase(last_idx);
                add_user_msg = false;
            } else {
                console::error("No message to regenerate.\n");
                continue;
            }
        } else if (string_starts_with(buffer, "/clear")) {
            ctx_cli.messages.clear();
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
                (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname  = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname  = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded text from '%s'\n", fname.c_str());
            continue;
        } else {
            // not a command
            cur_msg += buffer;
        }
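
        // the user message may contain media markers accumulated via
        // /image, /audio or /read; it becomes one chat history entry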
        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role",    "user"},
                {"content", cur_msg}
            });
            cur_msg.clear();
        }
        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role",    "assistant"},
            {"content", assistant_content}
        });
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }

        if (params.single_turn) {
            break;
        }
    }
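
    // stop the server loop and wait for the inference thread to finish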
    console::set_display(DISPLAY_TYPE_RESET);
    console::log("\nExiting...\n");
    ctx_cli.ctx_server.terminate();
    inference_thread.join();

    // bump the log level to display timings
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);
    llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());

    return 0;
}