server-http.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. #include "common.h"
  2. #include "server-http.h"
  3. #include "server-common.h"
  4. #include <cpp-httplib/httplib.h>
  5. #include <functional>
  6. #include <string>
  7. #include <thread>
  8. // auto generated files (see README.md for details)
  9. #include "index.html.gz.hpp"
  10. #include "loading.html.hpp"
  11. //
  12. // HTTP implementation using cpp-httplib
  13. //
  14. class server_http_context::Impl {
  15. public:
  16. std::unique_ptr<httplib::Server> srv;
  17. };
  18. server_http_context::server_http_context()
  19. : pimpl(std::make_unique<server_http_context::Impl>())
  20. {}
  21. server_http_context::~server_http_context() = default;
  22. static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
  23. // skip logging requests that are regularly sent, to avoid log spam
  24. if (req.path == "/health"
  25. || req.path == "/v1/health"
  26. || req.path == "/models"
  27. || req.path == "/v1/models"
  28. || req.path == "/props"
  29. || req.path == "/metrics"
  30. ) {
  31. return;
  32. }
  33. // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
  34. SRV_INF("done request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
  35. SRV_DBG("request: %s\n", req.body.c_str());
  36. SRV_DBG("response: %s\n", res.body.c_str());
  37. }
  38. bool server_http_context::init(const common_params & params) {
  39. path_prefix = params.api_prefix;
  40. port = params.port;
  41. hostname = params.hostname;
  42. auto & srv = pimpl->srv;
  43. #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  44. if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
  45. LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
  46. srv.reset(
  47. new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
  48. );
  49. } else {
  50. LOG_INF("Running without SSL\n");
  51. srv.reset(new httplib::Server());
  52. }
  53. #else
  54. if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
  55. LOG_ERR("Server is built without SSL support\n");
  56. return false;
  57. }
  58. srv.reset(new httplib::Server());
  59. #endif
  60. srv->set_default_headers({{"Server", "llama.cpp"}});
  61. srv->set_logger(log_server_request);
  62. srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
  63. // this is fail-safe; exceptions should already handled by `ex_wrapper`
  64. std::string message;
  65. try {
  66. std::rethrow_exception(ep);
  67. } catch (const std::exception & e) {
  68. message = e.what();
  69. } catch (...) {
  70. message = "Unknown Exception";
  71. }
  72. res.status = 500;
  73. res.set_content(message, "text/plain");
  74. LOG_ERR("got exception: %s\n", message.c_str());
  75. });
  76. srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
  77. if (res.status == 404) {
  78. res.set_content(
  79. safe_json_to_str(json {
  80. {"error", {
  81. {"message", "File Not Found"},
  82. {"type", "not_found_error"},
  83. {"code", 404}
  84. }}
  85. }),
  86. "application/json; charset=utf-8"
  87. );
  88. }
  89. // for other error codes, we skip processing here because it's already done by res->error()
  90. });
  91. // set timeouts and change hostname and port
  92. srv->set_read_timeout (params.timeout_read);
  93. srv->set_write_timeout(params.timeout_write);
  94. if (params.api_keys.size() == 1) {
  95. auto key = params.api_keys[0];
  96. std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
  97. LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
  98. } else if (params.api_keys.size() > 1) {
  99. LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
  100. }
  101. //
  102. // Middlewares
  103. //
  104. auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
  105. static const std::unordered_set<std::string> public_endpoints = {
  106. "/health",
  107. "/v1/health",
  108. "/models",
  109. "/v1/models",
  110. "/api/tags"
  111. };
  112. // If API key is not set, skip validation
  113. if (api_keys.empty()) {
  114. return true;
  115. }
  116. // If path is public or is static file, skip validation
  117. if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
  118. return true;
  119. }
  120. // Check for API key in the Authorization header
  121. std::string req_api_key = req.get_header_value("Authorization");
  122. if (req_api_key.empty()) {
  123. // retry with anthropic header
  124. req_api_key = req.get_header_value("X-Api-Key");
  125. }
  126. // remove the "Bearer " prefix if needed
  127. std::string prefix = "Bearer ";
  128. if (req_api_key.substr(0, prefix.size()) == prefix) {
  129. req_api_key = req_api_key.substr(prefix.size());
  130. }
  131. // validate the API key
  132. if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
  133. return true; // API key is valid
  134. }
  135. // API key is invalid or not provided
  136. res.status = 401;
  137. res.set_content(
  138. safe_json_to_str(json {
  139. {"error", {
  140. {"message", "Invalid API Key"},
  141. {"type", "authentication_error"},
  142. {"code", 401}
  143. }}
  144. }),
  145. "application/json; charset=utf-8"
  146. );
  147. LOG_WRN("Unauthorized: Invalid API Key\n");
  148. return false;
  149. };
  150. auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
  151. bool ready = is_ready.load();
  152. if (!ready) {
  153. auto tmp = string_split<std::string>(req.path, '.');
  154. if (req.path == "/" || tmp.back() == "html") {
  155. res.status = 503;
  156. res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
  157. } else {
  158. // no endpoints is allowed to be accessed when the server is not ready
  159. // this is to prevent any data races or inconsistent states
  160. res.status = 503;
  161. res.set_content(
  162. safe_json_to_str(json {
  163. {"error", {
  164. {"message", "Loading model"},
  165. {"type", "unavailable_error"},
  166. {"code", 503}
  167. }}
  168. }),
  169. "application/json; charset=utf-8"
  170. );
  171. }
  172. return false;
  173. }
  174. return true;
  175. };
  176. // register server middlewares
  177. srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
  178. res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
  179. // If this is OPTIONS request, skip validation because browsers don't include Authorization header
  180. if (req.method == "OPTIONS") {
  181. res.set_header("Access-Control-Allow-Credentials", "true");
  182. res.set_header("Access-Control-Allow-Methods", "GET, POST");
  183. res.set_header("Access-Control-Allow-Headers", "*");
  184. res.set_content("", "text/html"); // blank response, no data
  185. return httplib::Server::HandlerResponse::Handled; // skip further processing
  186. }
  187. if (!middleware_server_state(req, res)) {
  188. return httplib::Server::HandlerResponse::Handled;
  189. }
  190. if (!middleware_validate_api_key(req, res)) {
  191. return httplib::Server::HandlerResponse::Handled;
  192. }
  193. return httplib::Server::HandlerResponse::Unhandled;
  194. });
  195. int n_threads_http = params.n_threads_http;
  196. if (n_threads_http < 1) {
  197. // +2 threads for monitoring endpoints
  198. n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
  199. }
  200. LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
  201. srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
  202. //
  203. // Web UI setup
  204. //
  205. if (!params.webui) {
  206. LOG_INF("Web UI is disabled\n");
  207. } else {
  208. // register static assets routes
  209. if (!params.public_path.empty()) {
  210. // Set the base directory for serving static files
  211. bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
  212. if (!is_found) {
  213. LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
  214. return 1;
  215. }
  216. } else {
  217. // using embedded static index.html
  218. srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
  219. if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
  220. res.set_content("Error: gzip is not supported by this browser", "text/plain");
  221. } else {
  222. res.set_header("Content-Encoding", "gzip");
  223. // COEP and COOP headers, required by pyodide (python interpreter)
  224. res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
  225. res.set_header("Cross-Origin-Opener-Policy", "same-origin");
  226. res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
  227. }
  228. return false;
  229. });
  230. }
  231. }
  232. return true;
  233. }
  234. bool server_http_context::start() {
  235. // Bind and listen
  236. auto & srv = pimpl->srv;
  237. bool was_bound = false;
  238. bool is_sock = false;
  239. if (string_ends_with(std::string(hostname), ".sock")) {
  240. is_sock = true;
  241. LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
  242. srv->set_address_family(AF_UNIX);
  243. // bind_to_port requires a second arg, any value other than 0 should
  244. // simply get ignored
  245. was_bound = srv->bind_to_port(hostname, 8080);
  246. } else {
  247. LOG_INF("%s: binding port with default address family\n", __func__);
  248. // bind HTTP listen port
  249. if (port == 0) {
  250. int bound_port = srv->bind_to_any_port(hostname);
  251. was_bound = (bound_port >= 0);
  252. if (was_bound) {
  253. port = bound_port;
  254. }
  255. } else {
  256. was_bound = srv->bind_to_port(hostname, port);
  257. }
  258. }
  259. if (!was_bound) {
  260. LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
  261. return false;
  262. }
  263. // run the HTTP server in a thread
  264. thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
  265. srv->wait_until_ready();
  266. listening_address = is_sock ? string_format("unix://%s", hostname.c_str())
  267. : string_format("http://%s:%d", hostname.c_str(), port);
  268. return true;
  269. }
  270. void server_http_context::stop() const {
  271. if (pimpl->srv) {
  272. pimpl->srv->stop();
  273. }
  274. }
  275. static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
  276. for (const auto & [key, value] : headers) {
  277. res.set_header(key, value);
  278. }
  279. }
  280. static std::map<std::string, std::string> get_params(const httplib::Request & req) {
  281. std::map<std::string, std::string> params;
  282. for (const auto & [key, value] : req.params) {
  283. params[key] = value;
  284. }
  285. for (const auto & [key, value] : req.path_params) {
  286. params[key] = value;
  287. }
  288. return params;
  289. }
  290. static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
  291. std::map<std::string, std::string> headers;
  292. for (const auto & [key, value] : req.headers) {
  293. headers[key] = value;
  294. }
  295. return headers;
  296. }
  297. // using unique_ptr for request to allow safe capturing in lambdas
  298. using server_http_req_ptr = std::unique_ptr<server_http_req>;
  299. static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) {
  300. if (response->is_stream()) {
  301. res.status = response->status;
  302. set_headers(res, response->headers);
  303. std::string content_type = response->content_type;
  304. // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
  305. std::shared_ptr<server_http_req> q_ptr = std::move(request);
  306. std::shared_ptr<server_http_res> r_ptr = std::move(response);
  307. const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
  308. std::string chunk;
  309. bool has_next = response->next(chunk);
  310. if (!chunk.empty()) {
  311. // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
  312. sink.write(chunk.data(), chunk.size());
  313. SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
  314. }
  315. if (!has_next) {
  316. sink.done();
  317. SRV_DBG("%s", "http: stream ended\n");
  318. }
  319. return has_next;
  320. };
  321. const auto on_complete = [request = q_ptr, response = r_ptr](bool) mutable {
  322. response.reset(); // trigger the destruction of the response object
  323. request.reset(); // trigger the destruction of the request object
  324. };
  325. res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
  326. } else {
  327. res.status = response->status;
  328. set_headers(res, response->headers);
  329. res.set_content(response->data, response->content_type);
  330. }
  331. }
  332. void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
  333. pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
  334. server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
  335. get_params(req),
  336. get_headers(req),
  337. req.path,
  338. req.body,
  339. req.is_connection_closed
  340. });
  341. server_http_res_ptr response = handler(*request);
  342. process_handler_response(std::move(request), response, res);
  343. });
  344. }
  345. void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
  346. pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
  347. server_http_req_ptr request = std::make_unique<server_http_req>(server_http_req{
  348. get_params(req),
  349. get_headers(req),
  350. req.path,
  351. req.body,
  352. req.is_connection_closed
  353. });
  354. server_http_res_ptr response = handler(*request);
  355. process_handler_response(std::move(request), response, res);
  356. });
  357. }