// server.cpp

#include "server-context.h"
#include "server-http.h"
#include "server-models.h"

#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#include <atomic>
#include <cstdlib>   // for std::getenv
#include <exception>
#include <signal.h>
#include <thread>    // for std::thread::hardware_concurrency

#if defined(_WIN32)
#include <windows.h>
#endif
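
// shutdown_handler is assigned in main() once the server is set up; signal_handler()
// forwards interrupts (SIGINT/SIGTERM) to it, and a second interrupt forces an immediate exit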
static std::function<void(int)> shutdown_handler;
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

static inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // in case it hangs, we can force-terminate the server by hitting Ctrl+C twice
        // this is for a better developer experience; it can be removed once the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    shutdown_handler(signal);
}
// wrapper function that handles exceptions and logs errors
// this is to make sure handler_t never throws exceptions; instead, it returns an error response
static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
    return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
        std::string message;
        error_type error;

        try {
            return func(req);
        } catch (const std::invalid_argument & e) {
            // treat invalid_argument as invalid request (400)
            error   = ERROR_TYPE_INVALID_REQUEST;
            message = e.what();
        } catch (const std::exception & e) {
            // treat other exceptions as server error (500)
            error   = ERROR_TYPE_SERVER;
            message = e.what();
        } catch (...) {
            error   = ERROR_TYPE_SERVER;
            message = "unknown error";
        }

        auto res = std::make_unique<server_http_res>();
        res->status = 500;
        try {
            json error_data = format_error_response(message, error);
            res->status     = json_value(error_data, "code", 500);
            res->data       = safe_json_to_str({{ "error", error_data }});
            SRV_WRN("got exception: %s\n", res->data.c_str());
        } catch (const std::exception & e) {
            SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
            res->data = "Internal Server Error";
        }
        return res;
    };
}
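
// program entry point:
//  - parse command-line arguments and initialize the llama backend
//  - register the HTTP routes (proxied to child instances in router mode, served locally otherwise)
//  - start the HTTP server, run until shutdown, then clean up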
int main(int argc, char ** argv, char ** envp) {
    // own arguments required by this example
    common_params params;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
        return 1;
    }

    // validate batch size for embeddings
    // embeddings require all tokens to be processed in a single ubatch
    // see https://github.com/ggml-org/llama.cpp/issues/12836
    if (params.embedding && params.n_batch > params.n_ubatch) {
        LOG_WRN("%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params.n_batch, params.n_ubatch);
        LOG_WRN("%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params.n_ubatch);
        params.n_batch = params.n_ubatch;
    }

    if (params.n_parallel < 0) {
        LOG_INF("%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
        params.n_parallel = 4;
        params.kv_unified = true;
    }

    // for consistency between server router mode and single-model mode, we set the same model name as alias
    if (params.model_alias.empty() && !params.model.name.empty()) {
        params.model_alias = params.model.name;
    }

    common_init();

    // struct that contains llama context and inference
    server_context ctx_server;

    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

    server_http_context ctx_http;
    if (!ctx_http.init(params)) {
        LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
        return 1;
    }

    //
    // Router
    //

    // register API routes
    server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); });
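
    // router mode is enabled when no model path is given: requests are then proxied
    // to per-model child server instances managed via server_models_routes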
    bool is_router_server = params.model.path.empty();

    std::optional<server_models_routes> models_routes{};
    if (is_router_server) {
        // setup server instances manager
        try {
            models_routes.emplace(params, argc, argv, envp);
        } catch (const std::exception & e) {
            LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what());
            return 1;
        }

        // proxy handlers
        // note: routes.get_health stays the same
        routes.get_metrics                 = models_routes->proxy_get;
        routes.post_props                  = models_routes->proxy_post;
        routes.get_api_show                = models_routes->proxy_get;
        routes.post_completions            = models_routes->proxy_post;
        routes.post_completions_oai        = models_routes->proxy_post;
        routes.post_chat_completions       = models_routes->proxy_post;
        routes.post_anthropic_messages     = models_routes->proxy_post;
        routes.post_anthropic_count_tokens = models_routes->proxy_post;
        routes.post_infill                 = models_routes->proxy_post;
        routes.post_embeddings             = models_routes->proxy_post;
        routes.post_embeddings_oai         = models_routes->proxy_post;
        routes.post_rerank                 = models_routes->proxy_post;
        routes.post_tokenize               = models_routes->proxy_post;
        routes.post_detokenize             = models_routes->proxy_post;
        routes.post_apply_template         = models_routes->proxy_post;
        routes.get_lora_adapters           = models_routes->proxy_get;
        routes.post_lora_adapters          = models_routes->proxy_post;
        routes.get_slots                   = models_routes->proxy_get;
        routes.post_slots                  = models_routes->proxy_post;

        // custom routes for router
        routes.get_props  = models_routes->get_router_props;
        routes.get_models = models_routes->get_router_models;

        ctx_http.post("/models/load",   ex_wrapper(models_routes->post_router_models_load));
        ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload));
    }
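
    // map the HTTP endpoints to their handlers; each handler is wrapped with ex_wrapper()
    // so that exceptions become JSON error responses instead of escaping into the HTTP layer
    // (e.g. `curl http://localhost:8080/health`, assuming the default listen address)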
    ctx_http.get ("/health",              ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/v1/health",           ex_wrapper(routes.get_health)); // public endpoint (no API key check)
    ctx_http.get ("/metrics",             ex_wrapper(routes.get_metrics));
    ctx_http.get ("/props",               ex_wrapper(routes.get_props));
    ctx_http.post("/props",               ex_wrapper(routes.post_props));
    ctx_http.post("/api/show",            ex_wrapper(routes.get_api_show));
    ctx_http.get ("/models",              ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/v1/models",           ex_wrapper(routes.get_models)); // public endpoint (no API key check)
    ctx_http.get ("/api/tags",            ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
    ctx_http.post("/completion",          ex_wrapper(routes.post_completions)); // legacy
    ctx_http.post("/completions",         ex_wrapper(routes.post_completions));
    ctx_http.post("/v1/completions",      ex_wrapper(routes.post_completions_oai));
    ctx_http.post("/chat/completions",    ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
    ctx_http.post("/api/chat",            ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
    ctx_http.post("/v1/messages",              ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
    ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
    ctx_http.post("/infill",              ex_wrapper(routes.post_infill));
    ctx_http.post("/embedding",           ex_wrapper(routes.post_embeddings)); // legacy
    ctx_http.post("/embeddings",          ex_wrapper(routes.post_embeddings));
    ctx_http.post("/v1/embeddings",       ex_wrapper(routes.post_embeddings_oai));
    ctx_http.post("/rerank",              ex_wrapper(routes.post_rerank));
    ctx_http.post("/reranking",           ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/rerank",           ex_wrapper(routes.post_rerank));
    ctx_http.post("/v1/reranking",        ex_wrapper(routes.post_rerank));
    ctx_http.post("/tokenize",            ex_wrapper(routes.post_tokenize));
    ctx_http.post("/detokenize",          ex_wrapper(routes.post_detokenize));
    ctx_http.post("/apply-template",      ex_wrapper(routes.post_apply_template));

    // LoRA adapters hotswap
    ctx_http.get ("/lora-adapters",       ex_wrapper(routes.get_lora_adapters));
    ctx_http.post("/lora-adapters",       ex_wrapper(routes.post_lora_adapters));

    // Save & load slots
    ctx_http.get ("/slots",               ex_wrapper(routes.get_slots));
    ctx_http.post("/slots/:id_slot",      ex_wrapper(routes.post_slots));

    //
    // Start the server
    //

    std::function<void()> clean_up;

    if (is_router_server) {
        LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__);

        clean_up = [&models_routes]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            if (models_routes.has_value()) {
                models_routes->models.unload_all();
            }
            llama_backend_free();
        };

        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }
        ctx_http.is_ready.store(true);

        shutdown_handler = [&](int) {
            ctx_http.stop();
        };
    } else {
        // setup clean up function, to be called before exit
        clean_up = [&ctx_http, &ctx_server]() {
            SRV_INF("%s: cleaning up before exit...\n", __func__);
            ctx_http.stop();
            ctx_server.terminate();
            llama_backend_free();
        };

        // start the HTTP server before loading the model to be able to serve /health requests
        if (!ctx_http.start()) {
            clean_up();
            LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
            return 1;
        }

        // load the model
        LOG_INF("%s: loading model\n", __func__);
        if (!ctx_server.load_model(params)) {
            clean_up();
            if (ctx_http.thread.joinable()) {
                ctx_http.thread.join();
            }
            LOG_ERR("%s: exiting due to model loading error\n", __func__);
            return 1;
        }

        ctx_server.init();
        ctx_http.is_ready.store(true);
        LOG_INF("%s: model loaded\n", __func__);

        shutdown_handler = [&](int) {
            // this will unblock start_loop()
            ctx_server.terminate();
        };
    }
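
    // install the interrupt handlers (SIGINT/SIGTERM on POSIX, Ctrl+C console events on Windows)
    // so that shutdown_handler runs and the server can exit gracefully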
    // TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT,  &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
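
    // run until interrupted: in router mode the main thread just waits on the HTTP server thread,
    // otherwise it runs the inference loop until the shutdown handler terminates it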
    if (is_router_server) {
        LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: NOTE: router mode is experimental\n", __func__);
        LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__);
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join(); // keep the main thread alive
        }
        // when the HTTP server stops, clean up and exit
        clean_up();
    } else {
        LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
        LOG_INF("%s: starting the main loop...\n", __func__);

        // optionally, notify the router server that this instance is ready
        const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
        std::thread monitor_thread;
        if (router_port != nullptr) {
            monitor_thread = server_models::setup_child_server(shutdown_handler);
        }

        // this call blocks the main thread until queue_tasks.terminate() is called
        ctx_server.start_loop();

        clean_up();
        if (ctx_http.thread.joinable()) {
            ctx_http.thread.join();
        }
        if (monitor_thread.joinable()) {
            monitor_thread.join();
        }

        llama_memory_breakdown_print(ctx_server.get_llama_context());
    }

    return 0;
}