// cli.cpp — interactive chat CLI: streams chat completions from a local server_context
  1. #include "common.h"
  2. #include "arg.h"
  3. #include "console.h"
  4. // #include "log.h"
  5. #include "server-context.h"
  6. #include "server-task.h"
  7. #include <atomic>
  8. #include <fstream>
  9. #include <thread>
  10. #include <signal.h>
  11. #if defined(_WIN32)
  12. #define WIN32_LEAN_AND_MEAN
  13. #ifndef NOMINMAX
  14. # define NOMINMAX
  15. #endif
  16. #include <windows.h>
  17. #endif
// ASCII-art "llama.cpp" banner printed once at startup (see main()).
// NOTE(review): the raw string is emitted verbatim; column alignment of the
// block glyphs may have been lost in transcription — verify against upstream.
const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██ ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██ ▀████ ████▀ ████▀
██ ██
▀▀ ▀▀
)";
  27. static std::atomic<bool> g_is_interrupted = false;
  28. static bool should_stop() {
  29. return g_is_interrupted.load();
  30. }
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
// SIGINT/SIGTERM handler. First signal: set the interrupt flag so pollers of
// should_stop() can wind down gracefully. Second signal: force-exit at once.
static void signal_handler(int) {
if (g_is_interrupted.load()) {
// second Ctrl+C - exit immediately
// make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
fprintf(stdout, "\033[0m\n");
fflush(stdout);
std::exit(130); // 128 + SIGINT: conventional "terminated by Ctrl+C" exit code
}
g_is_interrupted.store(true);
}
#endif
// Bundles the inference server context with the CLI chat state: the OAI-style
// message history, attached media buffers, and the default per-task parameters
// derived from the command line.
struct cli_context {
server_context ctx_server;
json messages = json::array();          // chat history entries: {"role", "content"}
std::vector<raw_buffer> input_files;    // raw bytes of media files referenced by markers in messages
task_params defaults;                   // per-request defaults copied into every task
// thread for showing "loading" animation
std::atomic<bool> loading_show;
// Copies the relevant CLI parameters into the per-task defaults.
cli_context(const common_params & params) {
defaults.sampling = params.sampling;
defaults.speculative = params.speculative;
defaults.n_keep = params.n_keep;
defaults.n_predict = params.n_predict;
defaults.antiprompt = params.antiprompt;
defaults.stream = true; // make sure we always use streaming mode
defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
// defaults.return_progress = true; // TODO: show progress
defaults.oaicompat_chat_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
}
// Posts one completion task for the current message history and streams the
// reply to the console as partial results arrive, rendering reasoning
// ("thinking") deltas in a distinct display style. Returns the accumulated
// assistant content (possibly partial on error/interrupt); the timings of the
// last result seen are written to out_timings. Clears the interrupt flag on
// exit so the next prompt starts fresh.
std::string generate_completion(result_timings & out_timings) {
server_response_reader rd = ctx_server.get_response_reader();
{
// TODO: reduce some copies here in the future
server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
task.id = rd.get_new_id();
task.index = 0;
task.params = defaults; // copy
task.cli_input = messages; // copy
task.cli_files = input_files; // copy
rd.post_task({std::move(task)});
}
// wait for first result
console::spinner::start();
server_task_result_ptr result = rd.next(should_stop);
console::spinner::stop();
std::string curr_content;
bool is_thinking = false; // true while we are inside a reasoning block
while (result) {
if (should_stop()) {
break;
}
if (result->is_error()) {
json err_data = result->to_json();
if (err_data.contains("message")) {
console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
} else {
// no structured message — dump the whole error payload
console::error("Error: %s\n", err_data.dump().c_str());
}
return curr_content;
}
auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
if (res_partial) {
out_timings = std::move(res_partial->timings);
for (const auto & diff : res_partial->oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
// regular content ends any open reasoning block
if (is_thinking) {
console::log("\n[End thinking]\n\n");
console::set_display(DISPLAY_TYPE_RESET);
is_thinking = false;
}
curr_content += diff.content_delta;
console::log("%s", diff.content_delta.c_str());
console::flush();
}
if (!diff.reasoning_content_delta.empty()) {
console::set_display(DISPLAY_TYPE_REASONING);
if (!is_thinking) {
console::log("[Start thinking]\n");
}
is_thinking = true;
// reasoning text is shown but NOT added to curr_content
console::log("%s", diff.reasoning_content_delta.c_str());
console::flush();
}
}
}
auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
if (res_final) {
out_timings = std::move(res_final->timings);
break;
}
result = rd.next(should_stop);
}
g_is_interrupted.store(false);
// server_response_reader automatically cancels pending tasks upon destruction
return curr_content;
}
// TODO: support remote files in the future (http, https, etc)
// Loads a local file. Media files: raw bytes are appended to input_files and
// the multimodal placeholder marker is returned (to be embedded in the user
// message). Text files: the file content is returned directly. Returns "" if
// the file cannot be opened.
// NOTE(review): an empty text file also yields "" and is therefore reported
// as an open failure by callers — confirm whether that is intended.
std::string load_input_file(const std::string & fname, bool is_media) {
std::ifstream file(fname, std::ios::binary);
if (!file) {
return "";
}
if (is_media) {
raw_buffer buf;
buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
input_files.push_back(std::move(buf));
return mtmd_default_marker();
} else {
std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
return content;
}
}
};
  145. int main(int argc, char ** argv) {
  146. common_params params;
  147. params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs
  148. if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
  149. return 1;
  150. }
  151. // TODO: maybe support it later?
  152. if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
  153. console::error("--no-conversation is not supported by llama-cli\n");
  154. console::error("please use llama-completion instead\n");
  155. }
  156. common_init();
  157. // struct that contains llama context and inference
  158. cli_context ctx_cli(params);
  159. llama_backend_init();
  160. llama_numa_init(params.numa);
  161. // TODO: avoid using atexit() here by making `console` a singleton
  162. console::init(params.simple_io, params.use_color);
  163. atexit([]() { console::cleanup(); });
  164. console::set_display(DISPLAY_TYPE_RESET);
  165. #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
  166. struct sigaction sigint_action;
  167. sigint_action.sa_handler = signal_handler;
  168. sigemptyset (&sigint_action.sa_mask);
  169. sigint_action.sa_flags = 0;
  170. sigaction(SIGINT, &sigint_action, NULL);
  171. sigaction(SIGTERM, &sigint_action, NULL);
  172. #elif defined (_WIN32)
  173. auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
  174. return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
  175. };
  176. SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
  177. #endif
  178. console::log("\nLoading model... "); // followed by loading animation
  179. console::spinner::start();
  180. if (!ctx_cli.ctx_server.load_model(params)) {
  181. console::spinner::stop();
  182. console::error("\nFailed to load the model\n");
  183. return 1;
  184. }
  185. console::spinner::stop();
  186. console::log("\n");
  187. std::thread inference_thread([&ctx_cli]() {
  188. ctx_cli.ctx_server.start_loop();
  189. });
  190. auto inf = ctx_cli.ctx_server.get_meta();
  191. std::string modalities = "text";
  192. if (inf.has_inp_image) {
  193. modalities += ", vision";
  194. }
  195. if (inf.has_inp_audio) {
  196. modalities += ", audio";
  197. }
  198. if (!params.system_prompt.empty()) {
  199. ctx_cli.messages.push_back({
  200. {"role", "system"},
  201. {"content", params.system_prompt}
  202. });
  203. }
  204. console::log("\n");
  205. console::log("%s\n", LLAMA_ASCII_LOGO);
  206. console::log("build : %s\n", inf.build_info.c_str());
  207. console::log("model : %s\n", inf.model_name.c_str());
  208. console::log("modalities : %s\n", modalities.c_str());
  209. if (!params.system_prompt.empty()) {
  210. console::log("using custom system prompt\n");
  211. }
  212. console::log("\n");
  213. console::log("available commands:\n");
  214. console::log(" /exit or Ctrl+C stop or exit\n");
  215. console::log(" /regen regenerate the last response\n");
  216. console::log(" /clear clear the chat history\n");
  217. console::log(" /read add a text file\n");
  218. if (inf.has_inp_image) {
  219. console::log(" /image <file> add an image file\n");
  220. }
  221. if (inf.has_inp_audio) {
  222. console::log(" /audio <file> add an audio file\n");
  223. }
  224. console::log("\n");
  225. // interactive loop
  226. std::string cur_msg;
  227. while (true) {
  228. std::string buffer;
  229. console::set_display(DISPLAY_TYPE_USER_INPUT);
  230. if (params.prompt.empty()) {
  231. console::log("\n> ");
  232. std::string line;
  233. bool another_line = true;
  234. do {
  235. another_line = console::readline(line, params.multiline_input);
  236. buffer += line;
  237. } while (another_line);
  238. } else {
  239. // process input prompt from args
  240. for (auto & fname : params.image) {
  241. std::string marker = ctx_cli.load_input_file(fname, true);
  242. if (marker.empty()) {
  243. console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
  244. break;
  245. }
  246. console::log("Loaded media from '%s'\n", fname.c_str());
  247. cur_msg += marker;
  248. }
  249. buffer = params.prompt;
  250. if (buffer.size() > 500) {
  251. console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
  252. } else {
  253. console::log("\n> %s\n", buffer.c_str());
  254. }
  255. params.prompt.clear(); // only use it once
  256. }
  257. console::set_display(DISPLAY_TYPE_RESET);
  258. console::log("\n");
  259. if (should_stop()) {
  260. g_is_interrupted.store(false);
  261. break;
  262. }
  263. // remove trailing newline
  264. if (!buffer.empty() &&buffer.back() == '\n') {
  265. buffer.pop_back();
  266. }
  267. // skip empty messages
  268. if (buffer.empty()) {
  269. continue;
  270. }
  271. bool add_user_msg = true;
  272. // process commands
  273. if (string_starts_with(buffer, "/exit")) {
  274. break;
  275. } else if (string_starts_with(buffer, "/regen")) {
  276. if (ctx_cli.messages.size() >= 2) {
  277. size_t last_idx = ctx_cli.messages.size() - 1;
  278. ctx_cli.messages.erase(last_idx);
  279. add_user_msg = false;
  280. } else {
  281. console::error("No message to regenerate.\n");
  282. continue;
  283. }
  284. } else if (string_starts_with(buffer, "/clear")) {
  285. ctx_cli.messages.clear();
  286. ctx_cli.input_files.clear();
  287. console::log("Chat history cleared.\n");
  288. continue;
  289. } else if (
  290. (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
  291. (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
  292. // just in case (bad copy-paste for example), we strip all trailing/leading spaces
  293. std::string fname = string_strip(buffer.substr(7));
  294. std::string marker = ctx_cli.load_input_file(fname, true);
  295. if (marker.empty()) {
  296. console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
  297. continue;
  298. }
  299. cur_msg += marker;
  300. console::log("Loaded media from '%s'\n", fname.c_str());
  301. continue;
  302. } else if (string_starts_with(buffer, "/read ")) {
  303. std::string fname = string_strip(buffer.substr(6));
  304. std::string marker = ctx_cli.load_input_file(fname, false);
  305. if (marker.empty()) {
  306. console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
  307. continue;
  308. }
  309. cur_msg += marker;
  310. console::log("Loaded text from '%s'\n", fname.c_str());
  311. continue;
  312. } else {
  313. // not a command
  314. cur_msg += buffer;
  315. }
  316. // generate response
  317. if (add_user_msg) {
  318. ctx_cli.messages.push_back({
  319. {"role", "user"},
  320. {"content", cur_msg}
  321. });
  322. cur_msg.clear();
  323. }
  324. result_timings timings;
  325. std::string assistant_content = ctx_cli.generate_completion(timings);
  326. ctx_cli.messages.push_back({
  327. {"role", "assistant"},
  328. {"content", assistant_content}
  329. });
  330. console::log("\n");
  331. if (params.show_timings) {
  332. console::set_display(DISPLAY_TYPE_INFO);
  333. console::log("\n");
  334. console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
  335. console::set_display(DISPLAY_TYPE_RESET);
  336. }
  337. if (params.single_turn) {
  338. break;
  339. }
  340. }
  341. console::set_display(DISPLAY_TYPE_RESET);
  342. console::log("\nExiting...\n");
  343. ctx_cli.ctx_server.terminate();
  344. inference_thread.join();
  345. // bump the log level to display timings
  346. common_log_set_verbosity_thold(LOG_LEVEL_INFO);
  347. llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());
  348. return 0;
  349. }