|
@@ -39,7 +39,7 @@ using json = nlohmann::json;
|
|
|
struct server_params
|
|
struct server_params
|
|
|
{
|
|
{
|
|
|
std::string hostname = "127.0.0.1";
|
|
std::string hostname = "127.0.0.1";
|
|
|
- std::string api_key;
|
|
|
|
|
|
|
+ std::vector<std::string> api_keys;
|
|
|
std::string public_path = "examples/server/public";
|
|
std::string public_path = "examples/server/public";
|
|
|
int32_t port = 8080;
|
|
int32_t port = 8080;
|
|
|
int32_t read_timeout = 600;
|
|
int32_t read_timeout = 600;
|
|
@@ -2021,6 +2021,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms,
|
|
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
|
printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
|
|
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
|
printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str());
|
|
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
|
printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
|
|
|
|
|
+ printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
|
|
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
|
printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
|
|
|
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
|
printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
|
|
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
|
printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
|
|
@@ -2081,7 +2082,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
|
|
|
invalid_param = true;
|
|
invalid_param = true;
|
|
|
break;
|
|
break;
|
|
|
}
|
|
}
|
|
|
- sparams.api_key = argv[i];
|
|
|
|
|
|
|
+ sparams.api_keys.push_back(argv[i]);
|
|
|
|
|
+ }
|
|
|
|
|
+ else if (arg == "--api-key-file")
|
|
|
|
|
+ {
|
|
|
|
|
+ if (++i >= argc)
|
|
|
|
|
+ {
|
|
|
|
|
+ invalid_param = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ std::ifstream key_file(argv[i]);
|
|
|
|
|
+ if (!key_file) {
|
|
|
|
|
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
|
|
|
|
|
+ invalid_param = true;
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ std::string key;
|
|
|
|
|
+ while (std::getline(key_file, key)) {
|
|
|
|
|
+ if (key.size() > 0) {
|
|
|
|
|
+ sparams.api_keys.push_back(key);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ key_file.close();
|
|
|
}
|
|
}
|
|
|
else if (arg == "--timeout" || arg == "-to")
|
|
else if (arg == "--timeout" || arg == "-to")
|
|
|
{
|
|
{
|
|
@@ -2881,8 +2903,10 @@ int main(int argc, char **argv)
|
|
|
log_data["hostname"] = sparams.hostname;
|
|
log_data["hostname"] = sparams.hostname;
|
|
|
log_data["port"] = std::to_string(sparams.port);
|
|
log_data["port"] = std::to_string(sparams.port);
|
|
|
|
|
|
|
|
- if (!sparams.api_key.empty()) {
|
|
|
|
|
- log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
|
|
|
|
|
|
|
+ if (sparams.api_keys.size() == 1) {
|
|
|
|
|
+ log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4);
|
|
|
|
|
+ } else if (sparams.api_keys.size() > 1) {
|
|
|
|
|
+ log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
LOG_INFO("HTTP server listening", log_data);
|
|
LOG_INFO("HTTP server listening", log_data);
|
|
@@ -2912,7 +2936,7 @@ int main(int argc, char **argv)
|
|
|
// Middleware for API key validation
|
|
// Middleware for API key validation
|
|
|
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
|
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
|
|
|
// If API key is not set, skip validation
|
|
// If API key is not set, skip validation
|
|
|
- if (sparams.api_key.empty()) {
|
|
|
|
|
|
|
+ if (sparams.api_keys.empty()) {
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -2921,7 +2945,7 @@ int main(int argc, char **argv)
|
|
|
std::string prefix = "Bearer ";
|
|
std::string prefix = "Bearer ";
|
|
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
|
if (auth_header.substr(0, prefix.size()) == prefix) {
|
|
|
std::string received_api_key = auth_header.substr(prefix.size());
|
|
std::string received_api_key = auth_header.substr(prefix.size());
|
|
|
- if (received_api_key == sparams.api_key) {
|
|
|
|
|
|
|
+ if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
|
|
|
return true; // API key is valid
|
|
return true; // API key is valid
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|