Ver código fonte

server: split HTTP into its own interface (#17216)

* server: split HTTP into its own interface

* move server-http and httplib to its own file

* add the remaining endpoints

* fix exception/error handling

* renaming

* missing header

* fix missing windows header

* fix error responses from http layer

* fix slot save/restore handler

* fix case where only one stream chunk is returned

* add NOMINMAX

* do not call sink.write on empty data

* use safe_json_to_str for SSE

* clean up

* add some comments

* improve usage of next()

* bring back the "server is listening on" message

* more generic handler

* add req.headers

* move the chat template print to init()

* add req.path

* cont : minor

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Xuan-Son Nguyen 2 meses atrás
pai
commit
0de8878c96

+ 2 - 0
tools/server/CMakeLists.txt

@@ -14,6 +14,8 @@ endif()
 set(TARGET_SRCS
     server.cpp
     utils.hpp
+    server-http.cpp
+    server-http.h
 )
 set(PUBLIC_ASSETS
     index.html.gz

+ 386 - 0
tools/server/server-http.cpp

@@ -0,0 +1,386 @@
+#include "utils.hpp"
+#include "common.h"
+#include "server-http.h"
+
+#include <cpp-httplib/httplib.h>
+
+#include <functional>
+#include <string>
+#include <thread>
+
+// auto generated files (see README.md for details)
+#include "index.html.gz.hpp"
+#include "loading.html.hpp"
+
+//
+// HTTP implementation using cpp-httplib
+//
+
+class server_http_context::Impl {
+public:
+    std::unique_ptr<httplib::Server> srv;
+};
+
+server_http_context::server_http_context()
+    : pimpl(std::make_unique<server_http_context::Impl>())
+{}
+
+server_http_context::~server_http_context() = default;
+
+static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
+    // skip GH copilot requests when using default port
+    if (req.path == "/v1/health") {
+        return;
+    }
+
+    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch
+
+    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
+
+    SRV_DBG("request:  %s\n", req.body.c_str());
+    SRV_DBG("response: %s\n", res.body.c_str());
+}
+
+bool server_http_context::init(const common_params & params) {
+    path_prefix = params.api_prefix;
+    port = params.port;
+    hostname = params.hostname;
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
+        svr.reset(
+            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
+        );
+    } else {
+        LOG_INF("Running without SSL\n");
+        svr.reset(new httplib::Server());
+    }
+#else
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_ERR("Server is built without SSL support\n");
+        return false;
+    }
+    pimpl->srv.reset(new httplib::Server());
+#endif
+
+    auto & srv = pimpl->srv;
+    srv->set_default_headers({{"Server", "llama.cpp"}});
+    srv->set_logger(log_server_request);
+    srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
+        // this is fail-safe; exceptions should already handled by `ex_wrapper`
+
+        std::string message;
+        try {
+            std::rethrow_exception(ep);
+        } catch (const std::exception & e) {
+            message = e.what();
+        } catch (...) {
+            message = "Unknown Exception";
+        }
+
+        res.status = 500;
+        res.set_content(message, "text/plain");
+        LOG_ERR("got exception: %s\n", message.c_str());
+    });
+
+    srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
+        if (res.status == 404) {
+            res.set_content(
+                safe_json_to_str(json {
+                    {"error", {
+                        {"message", "File Not Found"},
+                        {"type", "not_found_error"},
+                        {"code", 404}
+                    }}
+                }),
+                "application/json; charset=utf-8"
+            );
+        }
+        // for other error codes, we skip processing here because it's already done by res->error()
+    });
+
+    // set timeouts and change hostname and port
+    srv->set_read_timeout (params.timeout_read);
+    srv->set_write_timeout(params.timeout_write);
+
+    if (params.api_keys.size() == 1) {
+        auto key = params.api_keys[0];
+        std::string substr = key.substr(std::max((int)(key.length() - 4), 0));
+        LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
+    } else if (params.api_keys.size() > 1) {
+        LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
+    }
+
+    //
+    // Middlewares
+    //
+
+    auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
+        static const std::unordered_set<std::string> public_endpoints = {
+            "/health",
+            "/v1/health",
+            "/models",
+            "/v1/models",
+            "/api/tags"
+        };
+
+        // If API key is not set, skip validation
+        if (api_keys.empty()) {
+            return true;
+        }
+
+        // If path is public or is static file, skip validation
+        if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
+            return true;
+        }
+
+        // Check for API key in the header
+        auto auth_header = req.get_header_value("Authorization");
+
+        std::string prefix = "Bearer ";
+        if (auth_header.substr(0, prefix.size()) == prefix) {
+            std::string received_api_key = auth_header.substr(prefix.size());
+            if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
+                return true; // API key is valid
+            }
+        }
+
+        // API key is invalid or not provided
+        res.status = 401;
+        res.set_content(
+            safe_json_to_str(json {
+                {"error", {
+                    {"message", "Invalid API Key"},
+                    {"type", "authentication_error"},
+                    {"code", 401}
+                }}
+            }),
+            "application/json; charset=utf-8"
+        );
+
+        LOG_WRN("Unauthorized: Invalid API Key\n");
+
+        return false;
+    };
+
+    auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
+        bool ready = is_ready.load();
+        if (!ready) {
+            auto tmp = string_split<std::string>(req.path, '.');
+            if (req.path == "/" || tmp.back() == "html") {
+                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
+                res.status = 503;
+            } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
+                // allow the models endpoint to be accessed during loading
+                return true;
+            } else {
+                res.status = 503;
+                res.set_content(
+                    safe_json_to_str(json {
+                        {"error", {
+                            {"message", "Loading model"},
+                            {"type", "unavailable_error"},
+                            {"code", 503}
+                        }}
+                    }),
+                    "application/json; charset=utf-8"
+                );
+            }
+            return false;
+        }
+        return true;
+    };
+
+    // register server middlewares
+    srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        // If this is OPTIONS request, skip validation because browsers don't include Authorization header
+        if (req.method == "OPTIONS") {
+            res.set_header("Access-Control-Allow-Credentials", "true");
+            res.set_header("Access-Control-Allow-Methods",     "GET, POST");
+            res.set_header("Access-Control-Allow-Headers",     "*");
+            res.set_content("", "text/html"); // blank response, no data
+            return httplib::Server::HandlerResponse::Handled; // skip further processing
+        }
+        if (!middleware_server_state(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        if (!middleware_validate_api_key(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
+        return httplib::Server::HandlerResponse::Unhandled;
+    });
+
+    int n_threads_http = params.n_threads_http;
+    if (n_threads_http < 1) {
+        // +2 threads for monitoring endpoints
+        n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
+    }
+    LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
+    srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };
+
+    //
+    // Web UI setup
+    //
+
+    if (!params.webui) {
+        LOG_INF("Web UI is disabled\n");
+    } else {
+        // register static assets routes
+        if (!params.public_path.empty()) {
+            // Set the base directory for serving static files
+            bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
+            if (!is_found) {
+                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+                return 1;
+            }
+        } else {
+            // using embedded static index.html
+            srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
+                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
+                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
+                } else {
+                    res.set_header("Content-Encoding", "gzip");
+                    // COEP and COOP headers, required by pyodide (python interpreter)
+                    res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
+                    res.set_header("Cross-Origin-Opener-Policy", "same-origin");
+                    res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
+                }
+                return false;
+            });
+        }
+    }
+    return true;
+}
+
+bool server_http_context::start() {
+    // Bind and listen
+
+    auto & srv = pimpl->srv;
+    bool was_bound = false;
+    bool is_sock = false;
+    if (string_ends_with(std::string(hostname), ".sock")) {
+        is_sock = true;
+        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
+        srv->set_address_family(AF_UNIX);
+        // bind_to_port requires a second arg, any value other than 0 should
+        // simply get ignored
+        was_bound = srv->bind_to_port(hostname, 8080);
+    } else {
+        LOG_INF("%s: binding port with default address family\n", __func__);
+        // bind HTTP listen port
+        if (port == 0) {
+            int bound_port = srv->bind_to_any_port(hostname);
+            was_bound = (bound_port >= 0);
+            if (was_bound) {
+                port = bound_port;
+            }
+        } else {
+            was_bound = srv->bind_to_port(hostname, port);
+        }
+    }
+
+    if (!was_bound) {
+        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
+        return false;
+    }
+
+    // run the HTTP server in a thread
+    thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
+    srv->wait_until_ready();
+
+    listening_address = is_sock ? string_format("unix://%s",    hostname.c_str())
+                                : string_format("http://%s:%d", hostname.c_str(), port);
+    return true;
+}
+
+void server_http_context::stop() const {
+    if (pimpl->srv) {
+        pimpl->srv->stop();
+    }
+}
+
+static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
+    for (const auto & [key, value] : headers) {
+        res.set_header(key, value);
+    }
+}
+
+static std::map<std::string, std::string> get_params(const httplib::Request & req) {
+    std::map<std::string, std::string> params;
+    for (const auto & [key, value] : req.params) {
+        params[key] = value;
+    }
+    for (const auto & [key, value] : req.path_params) {
+        params[key] = value;
+    }
+    return params;
+}
+
+static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
+    std::map<std::string, std::string> headers;
+    for (const auto & [key, value] : req.headers) {
+        headers[key] = value;
+    }
+    return headers;
+}
+
+static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
+    if (response->is_stream()) {
+        res.status = response->status;
+        set_headers(res, response->headers);
+        std::string content_type = response->content_type;
+        // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
+        std::shared_ptr<server_http_res> r_ptr = std::move(response);
+        const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
+            std::string chunk;
+            bool has_next = response->next(chunk);
+            if (!chunk.empty()) {
+                // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
+                sink.write(chunk.data(), chunk.size());
+                SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
+            }
+            if (!has_next) {
+                sink.done();
+                SRV_DBG("%s", "http: stream ended\n");
+            }
+            return has_next;
+        };
+        const auto on_complete = [response = r_ptr](bool) mutable {
+            response.reset(); // trigger the destruction of the response object
+        };
+        res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
+    } else {
+        res.status = response->status;
+        set_headers(res, response->headers);
+        res.set_content(response->data, response->content_type);
+    }
+}
+
+void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
+    pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
+        server_http_res_ptr response = handler(server_http_req{
+            get_params(req),
+            get_headers(req),
+            req.path,
+            req.body,
+            req.is_connection_closed
+        });
+        process_handler_response(response, res);
+    });
+}
+
+void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
+    pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
+        server_http_res_ptr response = handler(server_http_req{
+            get_params(req),
+            get_headers(req),
+            req.path,
+            req.body,
+            req.is_connection_closed
+        });
+        process_handler_response(response, res);
+    });
+}
+

+ 78 - 0
tools/server/server-http.h

@@ -0,0 +1,78 @@
+#pragma once
+
+#include <atomic>
+#include <functional>
+#include <map>
+#include <string>
+#include <thread>
+
+struct common_params;
+
+// generator-like API for HTTP response generation
+// this object response with one of the 2 modes:
+// 1) normal response: `data` contains the full response body
+// 2) streaming response: each call to next(output) generates the next chunk
+//    when next(output) returns false, no more data after the current chunk
+//    note: some chunks can be empty, in which case no data is sent for that chunk
+struct server_http_res {
+    std::string content_type = "application/json; charset=utf-8";
+    int status = 200;
+    std::string data;
+    std::map<std::string, std::string> headers;
+
+    // TODO: move this to a virtual function once we have proper polymorphism support
+    std::function<bool(std::string &)> next = nullptr;
+    bool is_stream() const {
+        return next != nullptr;
+    }
+
+    virtual ~server_http_res() = default;
+};
+
+// unique pointer, used by set_chunked_content_provider
+// httplib requires the stream provider to be stored in heap
+using server_http_res_ptr = std::unique_ptr<server_http_res>;
+
+struct server_http_req {
+    std::map<std::string, std::string> params; // path_params + query_params
+    std::map<std::string, std::string> headers; // reserved for future use
+    std::string path; // reserved for future use
+    std::string body;
+    const std::function<bool()> & should_stop;
+
+    std::string get_param(const std::string & key, const std::string & def = "") const {
+        auto it = params.find(key);
+        if (it != params.end()) {
+            return it->second;
+        }
+        return def;
+    }
+};
+
+struct server_http_context {
+    class Impl;
+    std::unique_ptr<Impl> pimpl;
+
+    std::thread thread; // server thread
+    std::atomic<bool> is_ready = false;
+
+    std::string path_prefix;
+    std::string hostname;
+    int port;
+
+    server_http_context();
+    ~server_http_context();
+
+    bool init(const common_params & params);
+    bool start();
+    void stop() const;
+
+    // note: the handler should never throw exceptions
+    using handler_t = std::function<server_http_res_ptr(const server_http_req & req)>;
+
+    void get(const std::string & path, const handler_t & handler) const;
+    void post(const std::string & path, const handler_t & handler) const;
+
+    // for debugging
+    std::string listening_address;
+};

Diferenças do arquivo suprimidas por serem muito extensas
+ 175 - 696
tools/server/server.cpp


+ 13 - 19
tools/server/utils.hpp

@@ -9,8 +9,6 @@
 #include "mtmd-helper.h"
 #include "chat.h"
 
-#include <cpp-httplib/httplib.h>
-
 #define JSON_ASSERT GGML_ASSERT
 #include <nlohmann/json.hpp>
 
@@ -426,6 +424,10 @@ static std::string gen_tool_call_id() {
 // other common utils
 //
 
+static std::string safe_json_to_str(const json & data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -453,29 +455,25 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }
 
+// format server-sent event (SSE), return the formatted string to send
 // note: if data is a json array, it will be sent as multiple events, one per item
-static bool server_sent_event(httplib::DataSink & sink, const json & data) {
-    static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
-        const std::string str =
-            "data: " +
-            data.dump(-1, ' ', false, json::error_handler_t::replace) +
+static std::string format_sse(const json & data) {
+    std::ostringstream ss;
+    auto send_single = [&ss](const json & data) {
+        ss << "data: " <<
+            safe_json_to_str(data) <<
             "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
-
-        LOG_DBG("data stream, to_send: %s", str.c_str());
-        return sink.write(str.c_str(), str.size());
     };
 
     if (data.is_array()) {
         for (const auto & item : data) {
-            if (!send_single(sink, item)) {
-                return false;
-            }
+            send_single(item);
         }
     } else {
-        return send_single(sink, data);
+        send_single(data);
     }
 
-    return true;
+    return ss.str();
 }
 
 //
@@ -954,10 +952,6 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
     return data;
 }
 
-static std::string safe_json_to_str(const json & data) {
-    return data.dump(-1, ' ', false, json::error_handler_t::replace);
-}
-
 static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
     std::vector<llama_token_data> cur;
     const auto * logits = llama_get_logits_ith(ctx, idx);

Alguns arquivos não foram mostrados porque muitos arquivos mudaram nesse diff