cli.cpp

#include "common.h"
#include "arg.h"
#include "console.h"
// #include "log.h"
#include "server-context.h"
#include "server-task.h"

#include <atomic>
#include <fstream>
#include <thread>
#include <signal.h>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#   define NOMINMAX
#endif
#include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██  ▀▀█▄ ███▄███▄  ▀▀█▄    ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██    ██    ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
                                    ██    ██
                                    ▀▀    ▀▀
)";
static std::atomic<bool> g_is_interrupted = false;

static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
        // make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
        fprintf(stdout, "\033[0m\n");
        fflush(stdout);
        std::exit(130);
    }
    g_is_interrupted.store(true);
}
#endif
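
// state for the interactive session: the underlying server context, the chat
// history (OpenAI-style message list) and any media files attached to the
// next message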
struct cli_context {
    server_context ctx_server;
    json messages = json::array();
    std::vector<raw_buffer> input_files;
    task_params defaults;

    // thread for showing "loading" animation
    std::atomic<bool> loading_show;

    cli_context(const common_params & params) {
        defaults.sampling          = params.sampling;
        defaults.speculative       = params.speculative;
        defaults.n_keep            = params.n_keep;
        defaults.n_predict         = params.n_predict;
        defaults.antiprompt        = params.antiprompt;
        defaults.stream            = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
    }
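
    // format the chat history, post a completion task and stream partial
    // results to the console; returns the generated content and fills
    // out_timings with the latest timing info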
    std::string generate_completion(result_timings & out_timings) {
        server_response_reader rd = ctx_server.get_response_reader();
        auto chat_params = format_chat();
        {
            // TODO: reduce some copies here in the future
            server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
            task.id         = rd.get_new_id();
            task.index      = 0;
            task.params     = defaults;           // copy
            task.cli_prompt = chat_params.prompt; // copy
            task.cli_files  = input_files;        // copy
            task.cli        = true;
            // chat template settings
            task.params.chat_parser_params = common_chat_parser_params(chat_params);
            task.params.chat_parser_params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
            rd.post_task({std::move(task)});
        }

        // wait for first result
        console::spinner::start();
        server_task_result_ptr result = rd.next(should_stop);
        console::spinner::stop();

        std::string curr_content;
        bool is_thinking = false;
        while (result) {
            if (should_stop()) {
                break;
            }
            if (result->is_error()) {
                json err_data = result->to_json();
                if (err_data.contains("message")) {
                    console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                return curr_content;
            }
            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
            if (res_partial) {
                out_timings = std::move(res_partial->timings);
                for (const auto & diff : res_partial->oaicompat_msg_diffs) {
                    if (!diff.content_delta.empty()) {
                        if (is_thinking) {
                            console::log("\n[End thinking]\n\n");
                            console::set_display(DISPLAY_TYPE_RESET);
                            is_thinking = false;
                        }
                        curr_content += diff.content_delta;
                        console::log("%s", diff.content_delta.c_str());
                        console::flush();
                    }
                    if (!diff.reasoning_content_delta.empty()) {
                        console::set_display(DISPLAY_TYPE_REASONING);
                        if (!is_thinking) {
                            console::log("[Start thinking]\n");
                        }
                        is_thinking = true;
                        console::log("%s", diff.reasoning_content_delta.c_str());
                        console::flush();
                    }
                }
            }
            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                break;
            }
            result = rd.next(should_stop);
        }
        g_is_interrupted.store(false);
        // server_response_reader automatically cancels pending tasks upon destruction
        return curr_content;
    }

    // TODO: support remote files in the future (http, https, etc)
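    // reads a local file; media files are buffered into input_files and
    // replaced by a marker in the prompt, text files are returned verbatim;
    // returns an empty string if the file cannot be opened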
    std::string load_input_file(const std::string & fname, bool is_media) {
        std::ifstream file(fname, std::ios::binary);
        if (!file) {
            return "";
        }
        if (is_media) {
            raw_buffer buf;
            buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            input_files.push_back(std::move(buf));
            return mtmd_default_marker();
        } else {
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            return content;
        }
    }
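
    // apply the model's chat template to the accumulated message list,
    // producing the prompt for the next completion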
    common_chat_params format_chat() {
        auto meta = ctx_server.get_meta();
        auto & chat_params = meta.chat_params;

        common_chat_templates_inputs inputs;
        inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
        inputs.tools                 = {}; // TODO
        inputs.tool_choice           = COMMON_CHAT_TOOL_CHOICE_NONE;
        inputs.json_schema           = ""; // TODO
        inputs.grammar               = ""; // TODO
        inputs.use_jinja             = chat_params.use_jinja;
        inputs.parallel_tool_calls   = false;
        inputs.add_generation_prompt = true;
        inputs.enable_thinking       = chat_params.enable_thinking;

        // Apply chat template to the list of messages
        return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
    }
};

int main(int argc, char ** argv) {
    common_params params;
    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
    }

    // TODO: maybe support it later?
    if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
        console::error("--no-conversation is not supported by llama-cli\n");
        console::error("please use llama-completion instead\n");
        return 1;
    }

    common_init();

    // struct that contains llama context and inference
    cli_context ctx_cli(params);

    llama_backend_init();
    llama_numa_init(params.numa);

    // TODO: avoid using atexit() here by making `console` a singleton
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });
    console::set_display(DISPLAY_TYPE_RESET);

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
        console::error("\nFailed to load the model\n");
        return 1;
    }
    console::spinner::stop();
    console::log("\n");
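
    // run the inference loop on a background thread; the main thread below
    // handles user input and console rendering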
    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });

    auto inf = ctx_cli.ctx_server.get_meta();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
    }
    if (inf.has_inp_audio) {
        modalities += ", audio";
    }

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role",    "system"},
            {"content", params.system_prompt}
        });
    }

    console::log("\n");
    console::log("%s\n", LLAMA_ASCII_LOGO);
    console::log("build      : %s\n", inf.build_info.c_str());
    console::log("model      : %s\n", inf.model_name.c_str());
    console::log("modalities : %s\n", modalities.c_str());
    if (!params.system_prompt.empty()) {
        console::log("using custom system prompt\n");
    }
    console::log("\n");
    console::log("available commands:\n");
    console::log("  /exit or Ctrl+C    stop or exit\n");
    console::log("  /regen             regenerate the last response\n");
    console::log("  /clear             clear the chat history\n");
    console::log("  /read <file>       add a text file\n");
    if (inf.has_inp_image) {
        console::log("  /image <file>      add an image file\n");
    }
    if (inf.has_inp_audio) {
        console::log("  /audio <file>      add an audio file\n");
    }
    console::log("\n");

    // interactive loop
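    // cur_msg accumulates file markers and typed text until the message is sent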
    std::string cur_msg;
    while (true) {
        std::string buffer;
        console::set_display(DISPLAY_TYPE_USER_INPUT);
        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
        } else {
            // process input prompt from args
            for (auto & fname : params.image) {
                std::string marker = ctx_cli.load_input_file(fname, true);
                if (marker.empty()) {
                    console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                    break;
                }
                console::log("Loaded media from '%s'\n", fname.c_str());
                cur_msg += marker;
            }
            buffer = params.prompt;
            if (buffer.size() > 500) {
                console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
        }
        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");

        if (should_stop()) {
            g_is_interrupted.store(false);
            break;
        }

        // remove trailing newline
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

        // skip empty messages
        if (buffer.empty()) {
            continue;
        }
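
        // "/regen" reuses the existing history, so it skips adding a new user message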
        bool add_user_msg = true;

        // process commands
        if (string_starts_with(buffer, "/exit")) {
            break;
        } else if (string_starts_with(buffer, "/regen")) {
            if (ctx_cli.messages.size() >= 2) {
                size_t last_idx = ctx_cli.messages.size() - 1;
                ctx_cli.messages.erase(last_idx);
                add_user_msg = false;
            } else {
                console::error("No message to regenerate.\n");
                continue;
            }
        } else if (string_starts_with(buffer, "/clear")) {
            ctx_cli.messages.clear();
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
                (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded text from '%s'\n", fname.c_str());
            continue;
        } else {
            // not a command
            cur_msg += buffer;
        }

        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role",    "user"},
                {"content", cur_msg}
            });
            cur_msg.clear();
        }

        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role",    "assistant"},
            {"content", assistant_content}
        });
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }

        if (params.single_turn) {
            break;
        }
    }

    console::set_display(DISPLAY_TYPE_RESET);
    console::log("\nExiting...\n");
    ctx_cli.ctx_server.terminate();
    inference_thread.join();

    // bump the log level to display timings
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);
    llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());

    return 0;
}