cli.cpp
#include "common.h"
#include "arg.h"
#include "console.h"
// #include "log.h"
#include "server-context.h"
#include "server-task.h"

#include <atomic>
#include <fstream>
#include <thread>

#include <signal.h>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
██ ██
▀▀ ▀▀
)";
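
// interrupt handling: the first Ctrl+C only raises this flag so that the
// current generation can be cancelled gracefully; a second Ctrl+C exits
// immediately (see signal_handler below)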
static std::atomic<bool> g_is_interrupted = false;

static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
        // make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
        fprintf(stdout, "\033[0m\n");
        fflush(stdout);
        std::exit(130);
    }
    g_is_interrupted.store(true);
}
#endif
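
// application state for the interactive CLI: owns the server context, the
// OpenAI-style chat history, and any media buffers queued for the next turn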
struct cli_context {
    server_context ctx_server;

    json messages = json::array();
    std::vector<raw_buffer> input_files;
    task_params defaults;

    // thread for showing "loading" animation
    std::atomic<bool> loading_show;

    cli_context(const common_params & params) {
        defaults.sampling    = params.sampling;
        defaults.speculative = params.speculative;
        defaults.n_keep      = params.n_keep;
        defaults.n_predict   = params.n_predict;
        defaults.antiprompt  = params.antiprompt;
        defaults.stream            = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
        defaults.oaicompat_chat_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
    }
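
    // post a single completion task for the current chat history and stream
    // the result: content deltas are printed as they arrive, reasoning deltas
    // are rendered between [Start thinking]/[End thinking] markers; returns
    // the accumulated assistant text (possibly partial if interrupted)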
    std::string generate_completion(result_timings & out_timings) {
        server_response_reader rd = ctx_server.get_response_reader();
        auto formatted = format_chat();
        {
            // TODO: reduce some copies here in the future
            server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
            task.id         = rd.get_new_id();
            task.index      = 0;
            task.params     = defaults;         // copy
            task.cli_prompt = formatted.prompt; // copy
            task.cli_files  = input_files;      // copy
            task.cli        = true;
            rd.post_task({std::move(task)});
        }

        // wait for first result
        console::spinner::start();
        server_task_result_ptr result = rd.next(should_stop);
        console::spinner::stop();

        std::string curr_content;
        bool is_thinking = false;

        while (result) {
            if (should_stop()) {
                break;
            }
            if (result->is_error()) {
                json err_data = result->to_json();
                if (err_data.contains("message")) {
                    console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                return curr_content;
            }
            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
            if (res_partial) {
                out_timings = std::move(res_partial->timings);
                for (const auto & diff : res_partial->oaicompat_msg_diffs) {
                    if (!diff.content_delta.empty()) {
                        if (is_thinking) {
                            console::log("\n[End thinking]\n\n");
                            console::set_display(DISPLAY_TYPE_RESET);
                            is_thinking = false;
                        }
                        curr_content += diff.content_delta;
                        console::log("%s", diff.content_delta.c_str());
                        console::flush();
                    }
                    if (!diff.reasoning_content_delta.empty()) {
                        console::set_display(DISPLAY_TYPE_REASONING);
                        if (!is_thinking) {
                            console::log("[Start thinking]\n");
                        }
                        is_thinking = true;
                        console::log("%s", diff.reasoning_content_delta.c_str());
                        console::flush();
                    }
                }
            }
            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                break;
            }
            result = rd.next(should_stop);
        }
        g_is_interrupted.store(false);
        // server_response_reader automatically cancels pending tasks upon destruction
        return curr_content;
    }
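
    // read a local file; for media the raw bytes are queued in input_files and
    // the multimodal marker is returned (to be embedded in the message text),
    // for text files the content itself is returned; empty string == failure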
    // TODO: support remote files in the future (http, https, etc)
    std::string load_input_file(const std::string & fname, bool is_media) {
        std::ifstream file(fname, std::ios::binary);
        if (!file) {
            return "";
        }
        if (is_media) {
            raw_buffer buf;
            buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            input_files.push_back(std::move(buf));
            return mtmd_default_marker();
        } else {
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            return content;
        }
    }
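
    // render the accumulated messages into a single prompt via the model's
    // chat template; tool calls, JSON schema and grammar are not wired up yet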
    common_chat_params format_chat() {
        auto meta = ctx_server.get_meta();
        auto & chat_params = meta.chat_params;

        common_chat_templates_inputs inputs;
        inputs.messages              = common_chat_msgs_parse_oaicompat(messages);
        inputs.tools                 = {}; // TODO
        inputs.tool_choice           = COMMON_CHAT_TOOL_CHOICE_NONE;
        inputs.json_schema           = ""; // TODO
        inputs.grammar               = ""; // TODO
        inputs.use_jinja             = chat_params.use_jinja;
        inputs.parallel_tool_calls   = false;
        inputs.add_generation_prompt = true;
        inputs.reasoning_format      = chat_params.reasoning_format;
        inputs.enable_thinking       = chat_params.enable_thinking;

        // Apply chat template to the list of messages
        return common_chat_templates_apply(chat_params.tmpls.get(), inputs);
    }
};

int main(int argc, char ** argv) {
    common_params params;
    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
    }

    // TODO: maybe support it later?
    if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
        console::error("--no-conversation is not supported by llama-cli\n");
        console::error("please use llama-completion instead\n");
    }

    common_init();

    // struct that contains llama context and inference
    cli_context ctx_cli(params);

    llama_backend_init();
    llama_numa_init(params.numa);

    // TODO: avoid using atexit() here by making `console` a singleton
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });
    console::set_display(DISPLAY_TYPE_RESET);

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset(&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
        console::error("\nFailed to load the model\n");
        return 1;
    }
    console::spinner::stop();
    console::log("\n");
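
    // run the server loop on a background thread; the main thread drives the
    // interactive prompt and talks to it via server_response_reader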
    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });

    auto inf = ctx_cli.ctx_server.get_meta();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
    }
    if (inf.has_inp_audio) {
        modalities += ", audio";
    }

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role",    "system"},
            {"content", params.system_prompt}
        });
    }

    console::log("\n");
    console::log("%s\n", LLAMA_ASCII_LOGO);
    console::log("build      : %s\n", inf.build_info.c_str());
    console::log("model      : %s\n", inf.model_name.c_str());
    console::log("modalities : %s\n", modalities.c_str());
    if (!params.system_prompt.empty()) {
        console::log("using custom system prompt\n");
    }
    console::log("\n");
    console::log("available commands:\n");
    console::log("  /exit or Ctrl+C    stop or exit\n");
    console::log("  /regen             regenerate the last response\n");
    console::log("  /clear             clear the chat history\n");
    console::log("  /read              add a text file\n");
    if (inf.has_inp_image) {
        console::log("  /image <file>      add an image file\n");
    }
    if (inf.has_inp_audio) {
        console::log("  /audio <file>      add an audio file\n");
    }
    console::log("\n");

    // interactive loop
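    // each turn: read user input (or consume the --prompt argument once),
    // handle slash commands, then append the user message and stream the
    // assistant reply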
    std::string cur_msg;
    while (true) {
        std::string buffer;
        console::set_display(DISPLAY_TYPE_USER_INPUT);
        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
        } else {
            // process input prompt from args
            for (auto & fname : params.image) {
                std::string marker = ctx_cli.load_input_file(fname, true);
                if (marker.empty()) {
                    console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                    break;
                }
                console::log("Loaded media from '%s'\n", fname.c_str());
                cur_msg += marker;
            }
            buffer = params.prompt;
            if (buffer.size() > 500) {
                console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
        }
        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");

        if (should_stop()) {
            g_is_interrupted.store(false);
            break;
        }

        // remove trailing newline
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

        // skip empty messages
        if (buffer.empty()) {
            continue;
        }

        bool add_user_msg = true;

        // process commands
        if (string_starts_with(buffer, "/exit")) {
            break;
        } else if (string_starts_with(buffer, "/regen")) {
            if (ctx_cli.messages.size() >= 2) {
                size_t last_idx = ctx_cli.messages.size() - 1;
                ctx_cli.messages.erase(last_idx);
                add_user_msg = false;
            } else {
                console::error("No message to regenerate.\n");
                continue;
            }
        } else if (string_starts_with(buffer, "/clear")) {
            ctx_cli.messages.clear();
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
                (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded text from '%s'\n", fname.c_str());
            continue;
        } else {
            // not a command
            cur_msg += buffer;
        }

        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role",    "user"},
                {"content", cur_msg}
            });
            cur_msg.clear();
        }

        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role",    "assistant"},
            {"content", assistant_content}
        });
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }

        if (params.single_turn) {
            break;
        }
    }

    console::set_display(DISPLAY_TYPE_RESET);
    console::log("\nExiting...\n");
    ctx_cli.ctx_server.terminate();
    inference_thread.join();

    // bump the log level to display timings
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);
    llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());

    return 0;
}