// cli.cpp
#include "common.h"
#include "arg.h"
#include "console.h"
// #include "log.h"
#include "server-context.h"
#include "server-task.h"

#include <atomic>
#include <fstream>
#include <thread>

#include <signal.h>

#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
# define NOMINMAX
#endif
#include <windows.h>
#endif

const char * LLAMA_ASCII_LOGO = R"(
▄▄ ▄▄
██ ██
██ ██ ▀▀█▄ ███▄███▄ ▀▀█▄ ▄████ ████▄ ████▄
██ ██ ▄█▀██ ██ ██ ██ ▄█▀██ ██    ██ ██ ██ ██
██ ██ ▀█▄██ ██ ██ ██ ▀█▄██ ██    ▀████ ████▀ ████▀
██ ██
▀▀ ▀▀
)";

static std::atomic<bool> g_is_interrupted = false;

static bool should_stop() {
    return g_is_interrupted.load();
}

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
static void signal_handler(int) {
    if (g_is_interrupted.load()) {
        // second Ctrl+C - exit immediately
        // make sure to clear colors before exiting (not using LOG or console.cpp here to avoid deadlock)
        fprintf(stdout, "\033[0m\n");
        fflush(stdout);
        std::exit(130);
    }
    g_is_interrupted.store(true);
}
#endif
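
// state of the interactive chat session: the underlying server context,
// the message history, any attached media files, and the default task parameters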
struct cli_context {
    server_context ctx_server;

    json messages = json::array();
    std::vector<raw_buffer> input_files;
    task_params defaults;

    // flag used by the thread showing the "loading" animation
    std::atomic<bool> loading_show;

    cli_context(const common_params & params) {
        defaults.sampling    = params.sampling;
        defaults.speculative = params.speculative;
        defaults.n_keep      = params.n_keep;
        defaults.n_predict   = params.n_predict;
        defaults.antiprompt  = params.antiprompt;
        defaults.stream            = true; // make sure we always use streaming mode
        defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way
        // defaults.return_progress = true; // TODO: show progress
        defaults.oaicompat_chat_syntax.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
    }
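
    // post a completion task for the current conversation, stream partial results
    // to the console (including reasoning content, if any), and return the
    // accumulated assistant text; out_timings receives the latest timings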
    std::string generate_completion(result_timings & out_timings) {
        server_response_reader rd = ctx_server.get_response_reader();
        {
            // TODO: reduce some copies here in the future
            server_task task = server_task(SERVER_TASK_TYPE_COMPLETION);
            task.id        = rd.get_new_id();
            task.index     = 0;
            task.params    = defaults;    // copy
            task.cli_input = messages;    // copy
            task.cli_files = input_files; // copy
            rd.post_task({std::move(task)});
        }

        // wait for first result
        console::spinner::start();
        server_task_result_ptr result = rd.next(should_stop);
        console::spinner::stop();

        std::string curr_content;
        bool is_thinking = false;

        while (result) {
            if (should_stop()) {
                break;
            }
            if (result->is_error()) {
                json err_data = result->to_json();
                if (err_data.contains("message")) {
                    console::error("Error: %s\n", err_data["message"].get<std::string>().c_str());
                } else {
                    console::error("Error: %s\n", err_data.dump().c_str());
                }
                return curr_content;
            }

            auto res_partial = dynamic_cast<server_task_result_cmpl_partial *>(result.get());
            if (res_partial) {
                out_timings = std::move(res_partial->timings);
                for (const auto & diff : res_partial->oaicompat_msg_diffs) {
                    if (!diff.content_delta.empty()) {
                        if (is_thinking) {
                            console::log("\n[End thinking]\n\n");
                            console::set_display(DISPLAY_TYPE_RESET);
                            is_thinking = false;
                        }
                        curr_content += diff.content_delta;
                        console::log("%s", diff.content_delta.c_str());
                        console::flush();
                    }
                    if (!diff.reasoning_content_delta.empty()) {
                        console::set_display(DISPLAY_TYPE_REASONING);
                        if (!is_thinking) {
                            console::log("[Start thinking]\n");
                        }
                        is_thinking = true;
                        console::log("%s", diff.reasoning_content_delta.c_str());
                        console::flush();
                    }
                }
            }

            auto res_final = dynamic_cast<server_task_result_cmpl_final *>(result.get());
            if (res_final) {
                out_timings = std::move(res_final->timings);
                break;
            }

            result = rd.next(should_stop);
        }

        g_is_interrupted.store(false);
        // server_response_reader automatically cancels pending tasks upon destruction
        return curr_content;
    }
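
    // returns the text to append to the current message: the file contents for
    // text files, or the multimodal marker for media files (whose raw bytes are
    // stored in input_files); returns an empty string if the file cannot be opened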
    // TODO: support remote files in the future (http, https, etc)
    std::string load_input_file(const std::string & fname, bool is_media) {
        std::ifstream file(fname, std::ios::binary);
        if (!file) {
            return "";
        }
        if (is_media) {
            raw_buffer buf;
            buf.assign((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            input_files.push_back(std::move(buf));
            return mtmd_default_marker();
        } else {
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            return content;
        }
    }
};

int main(int argc, char ** argv) {
    common_params params;
    params.verbosity = LOG_LEVEL_ERROR; // by default, less verbose logs

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_CLI)) {
        return 1;
    }

    // TODO: maybe support it later?
    if (params.conversation_mode == COMMON_CONVERSATION_MODE_DISABLED) {
        console::error("--no-conversation is not supported by llama-cli\n");
        console::error("please use llama-completion instead\n");
        return 1;
    }

    common_init();

    // struct that contains llama context and inference
    cli_context ctx_cli(params);

    llama_backend_init();
    llama_numa_init(params.numa);

    // TODO: avoid using atexit() here by making `console` a singleton
    console::init(params.simple_io, params.use_color);
    atexit([]() { console::cleanup(); });
    console::set_display(DISPLAY_TYPE_RESET);
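
    // install interrupt handlers: the first Ctrl+C stops the current generation,
    // a second one exits the program (see signal_handler above)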
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    console::log("\nLoading model... "); // followed by loading animation
    console::spinner::start();
    if (!ctx_cli.ctx_server.load_model(params)) {
        console::spinner::stop();
        console::error("\nFailed to load the model\n");
        return 1;
    }
    ctx_cli.ctx_server.init();
    console::spinner::stop();
    console::log("\n");

    std::thread inference_thread([&ctx_cli]() {
        ctx_cli.ctx_server.start_loop();
    });
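
    // query the loaded model's capabilities to decide which modalities
    // (and which commands) to advertise to the user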
    auto inf = ctx_cli.ctx_server.get_info();
    std::string modalities = "text";
    if (inf.has_inp_image) {
        modalities += ", vision";
    }
    if (inf.has_inp_audio) {
        modalities += ", audio";
    }

    if (!params.system_prompt.empty()) {
        ctx_cli.messages.push_back({
            {"role", "system"},
            {"content", params.system_prompt}
        });
    }

    console::log("\n");
    console::log("%s\n", LLAMA_ASCII_LOGO);
    console::log("build      : %s\n", inf.build_info.c_str());
    console::log("model      : %s\n", inf.model_name.c_str());
    console::log("modalities : %s\n", modalities.c_str());
    if (!params.system_prompt.empty()) {
        console::log("using custom system prompt\n");
    }
    console::log("\n");
    console::log("available commands:\n");
    console::log("  /exit or Ctrl+C    stop or exit\n");
    console::log("  /regen             regenerate the last response\n");
    console::log("  /clear             clear the chat history\n");
    console::log("  /read <file>       add a text file\n");
    if (inf.has_inp_image) {
        console::log("  /image <file>      add an image file\n");
    }
    if (inf.has_inp_audio) {
        console::log("  /audio <file>      add an audio file\n");
    }
    console::log("\n");

    // interactive loop
    std::string cur_msg;
    while (true) {
        std::string buffer;
        console::set_display(DISPLAY_TYPE_USER_INPUT);
        if (params.prompt.empty()) {
            console::log("\n> ");
            std::string line;
            bool another_line = true;
            do {
                another_line = console::readline(line, params.multiline_input);
                buffer += line;
            } while (another_line);
        } else {
            // process input prompt from args
            for (auto & fname : params.image) {
                std::string marker = ctx_cli.load_input_file(fname, true);
                if (marker.empty()) {
                    console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                    break;
                }
                console::log("Loaded media from '%s'\n", fname.c_str());
                cur_msg += marker;
            }
            buffer = params.prompt;
            if (buffer.size() > 500) {
                console::log("\n> %s ... (truncated)\n", buffer.substr(0, 500).c_str());
            } else {
                console::log("\n> %s\n", buffer.c_str());
            }
            params.prompt.clear(); // only use it once
        }
        console::set_display(DISPLAY_TYPE_RESET);
        console::log("\n");

        if (should_stop()) {
            g_is_interrupted.store(false);
            break;
        }

        // remove trailing newline
        if (!buffer.empty() && buffer.back() == '\n') {
            buffer.pop_back();
        }

        // skip empty messages
        if (buffer.empty()) {
            continue;
        }
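
        // add_user_msg stays true unless a command (like /regen) reuses the
        // existing history instead of appending a new user message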
        bool add_user_msg = true;

        // process commands
        if (string_starts_with(buffer, "/exit")) {
            break;
        } else if (string_starts_with(buffer, "/regen")) {
            if (ctx_cli.messages.size() >= 2) {
                size_t last_idx = ctx_cli.messages.size() - 1;
                ctx_cli.messages.erase(last_idx);
                add_user_msg = false;
            } else {
                console::error("No message to regenerate.\n");
                continue;
            }
        } else if (string_starts_with(buffer, "/clear")) {
            ctx_cli.messages.clear();
            ctx_cli.input_files.clear();
            console::log("Chat history cleared.\n");
            continue;
        } else if (
                (string_starts_with(buffer, "/image ") && inf.has_inp_image) ||
                (string_starts_with(buffer, "/audio ") && inf.has_inp_audio)) {
            // just in case (bad copy-paste for example), we strip all trailing/leading spaces
            std::string fname = string_strip(buffer.substr(7));
            std::string marker = ctx_cli.load_input_file(fname, true);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded media from '%s'\n", fname.c_str());
            continue;
        } else if (string_starts_with(buffer, "/read ")) {
            std::string fname = string_strip(buffer.substr(6));
            std::string marker = ctx_cli.load_input_file(fname, false);
            if (marker.empty()) {
                console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str());
                continue;
            }
            cur_msg += marker;
            console::log("Loaded text from '%s'\n", fname.c_str());
            continue;
        } else {
            // not a command
            cur_msg += buffer;
        }

        // generate response
        if (add_user_msg) {
            ctx_cli.messages.push_back({
                {"role", "user"},
                {"content", cur_msg}
            });
            cur_msg.clear();
        }
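
        // run the completion; the response is streamed to the console and the
        // timings are filled in for the optional report below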
        result_timings timings;
        std::string assistant_content = ctx_cli.generate_completion(timings);
        ctx_cli.messages.push_back({
            {"role", "assistant"},
            {"content", assistant_content}
        });
        console::log("\n");

        if (params.show_timings) {
            console::set_display(DISPLAY_TYPE_INFO);
            console::log("\n");
            console::log("[ Prompt: %.1f t/s | Generation: %.1f t/s ]\n", timings.prompt_per_second, timings.predicted_per_second);
            console::set_display(DISPLAY_TYPE_RESET);
        }

        if (params.single_turn) {
            break;
        }
    }
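
    // graceful shutdown: stop the server loop and wait for the inference thread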
    console::set_display(DISPLAY_TYPE_RESET);
    console::log("\nExiting...\n");
    ctx_cli.ctx_server.terminate();
    inference_thread.join();

    // bump the log level to display timings
    common_log_set_verbosity_thold(LOG_LEVEL_INFO);
    llama_memory_breakdown_print(ctx_cli.ctx_server.get_llama_context());

    return 0;
}