// server-models.cpp

#include "server-common.h"
#include "server-models.h"
#include "download.h"

#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>

#include <functional>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <cstring>
#include <atomic>
#include <chrono>
#include <queue>

#ifdef _WIN32
#include <winsock2.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif

#define CMD_EXIT "exit"

struct local_model {
    std::string name;
    std::string path;
    std::string path_mmproj;
};
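
// scan a directory for models; a model is either a single .gguf file in the
// directory, or a subdirectory containing the .gguf file(s) (possibly sharded,
// with an optional mmproj file)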
static std::vector<local_model> list_local_models(const std::string & dir) {
    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
    }
    std::vector<local_model> models;
    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
        auto files = fs_list(subdir_path, false);
        common_file_info model_file;
        common_file_info first_shard_file;
        common_file_info mmproj_file;
        for (const auto & file : files) {
            if (string_ends_with(file.name, ".gguf")) {
                if (file.name.find("mmproj") != std::string::npos) {
                    mmproj_file = file;
                } else if (file.name.find("-00001-of-") != std::string::npos) {
                    first_shard_file = file;
                } else {
                    model_file = file;
                }
            }
        }
        // prefer the first shard if the model is sharded, otherwise use the single model file
        local_model model{
            /* name */ name,
            /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
            /* path_mmproj */ mmproj_file.path // can be empty
        };
        if (!model.path.empty()) {
            models.push_back(model);
        }
    };
    auto files = fs_list(dir, true);
    for (const auto & file : files) {
        if (file.is_dir) {
            scan_subdir(file.path, file.name);
        } else if (string_ends_with(file.name, ".gguf")) {
            // single file model
            std::string name = file.name;
            string_replace_all(name, ".gguf", "");
            local_model model{
                /* name */ name,
                /* path */ file.path,
                /* path_mmproj */ ""
            };
            models.push_back(model);
        }
    }
    return models;
}

//
// server_models
//

server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp) : base_params(params) {
    for (int i = 0; i < argc; i++) {
        base_args.push_back(std::string(argv[i]));
    }
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    // TODO: allow refreshing cached model list
    // add cached models
    auto cached_models = common_list_cached_models();
    for (const auto & model : cached_models) {
        server_model_meta meta{
            /* name */ model.to_string(),
            /* path */ model.manifest_path,
            /* path_mmproj */ "", // auto-detected when loading
            /* in_cache */ true,
            /* port */ 0,
            /* status */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */ 0,
            /* args */ std::vector<std::string>(),
            /* exit_code */ 0
        };
        mapping[meta.name] = instance_t{
            /* subproc */ std::make_shared<subprocess_s>(),
            /* th */ std::thread(),
            /* meta */ meta
        };
    }
    // add local models specified via --models-dir
    if (!params.models_dir.empty()) {
        auto local_models = list_local_models(params.models_dir);
        for (const auto & model : local_models) {
            if (mapping.find(model.name) != mapping.end()) {
                // already exists in cached models, skip
                continue;
            }
            server_model_meta meta{
                /* name */ model.name,
                /* path */ model.path,
                /* path_mmproj */ model.path_mmproj,
                /* in_cache */ false,
                /* port */ 0,
                /* status */ SERVER_MODEL_STATUS_UNLOADED,
                /* last_used */ 0,
                /* args */ std::vector<std::string>(),
                /* exit_code */ 0
            };
            mapping[meta.name] = instance_t{
                /* subproc */ std::make_shared<subprocess_s>(),
                /* th */ std::thread(),
                /* meta */ meta
            };
        }
    }
}

void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        it->second.meta = meta;
    }
    cv.notify_all(); // notify wait_until_loaded
}

bool server_models::has_model(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    return mapping.find(name) != mapping.end();
}

std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        return it->second.meta;
    }
    return std::nullopt;
}
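
// find a free TCP port by binding to port 0 and reading back the port assigned
// by the OS; the socket is closed before returning, so there is a small window
// in which another process could grab the port before the child server binds it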
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0);
    if (bind(sock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    if (getsockname(sock, (struct sockaddr *) &serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    int port = ntohs(serv_addr.sin_port);
    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif
    return port;
}

// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
    std::vector<char *> result;
    result.reserve(vec.size() + 1);
    for (const auto & s : vec) {
        result.push_back(const_cast<char *>(s.c_str()));
    }
    result.push_back(nullptr);
    return result;
}

std::vector<server_model_meta> server_models::get_all_meta() {
    std::lock_guard<std::mutex> lk(mutex);
    std::vector<server_model_meta> result;
    result.reserve(mapping.size());
    for (const auto & [name, inst] : mapping) {
        result.push_back(inst.meta);
    }
    return result;
}
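
// when --models-max is set, evict the least recently used active instance
// once the number of active instances reaches the limit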
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one instance (the least recently used) if we have reached models_max
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms();
    size_t count_active = 0;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    if (!lru_model_name.empty() && count_active >= (size_t) base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
    }
}
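
// if `key` is already present in args, replace the value that follows it;
// otherwise append `key value`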
static void add_or_replace_arg(std::vector<std::string> & args, const std::string & key, const std::string & value) {
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i] == key && i + 1 < args.size()) {
            args[i + 1] = value;
            return;
        }
    }
    // not found, append
    args.push_back(key);
    args.push_back(value);
}
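
// spawn a new server instance for the given model: pick a free port, build the
// child argv/env from the router's own args, then start a thread that forwards
// the child's output to the router log and reaps the process when it exits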
void server_models::load(const std::string & name, bool auto_load) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model %s is already loading or loaded, skipping\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        std::vector<std::string> child_args;
        if (auto_load && !meta.args.empty()) {
            child_args = meta.args; // copy previous args
        } else {
            child_args = base_args; // copy
            if (inst.meta.in_cache) {
                add_or_replace_arg(child_args, "-hf", inst.meta.name);
            } else {
                add_or_replace_arg(child_args, "-m", inst.meta.path);
                if (!inst.meta.path_mmproj.empty()) {
                    add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj);
                }
            }
        }
        // set model args
        add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port));
        add_or_replace_arg(child_args, "--alias", inst.meta.name);
        std::vector<std::string> child_env = base_env; // copy
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update exit code and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all();
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}

static void interrupt_subprocess(FILE * stdin_file) {
    // because subprocess.h does not provide a way to send SIGINT,
    // we send a command asking the child process to exit gracefully
    if (stdin_file) {
        fprintf(stdin_file, "%s\n", CMD_EXIT);
        fflush(stdin_file);
    }
}

void server_models::unload(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        if (it->second.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            interrupt_subprocess(it->second.stdin_file);
            // status change will be handled by the managing thread
        } else {
            SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
        }
    }
}

void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // move the thread to the join list so we join without holding the lock (avoids deadlock)
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}

void server_models::update_status(const std::string & name, server_model_status status) {
    // for now, we only allow updating to LOADED status
    if (status != SERVER_MODEL_STATUS_LOADED) {
        throw std::runtime_error("invalid status value");
    }
    auto meta = get_meta(name);
    if (meta.has_value()) {
        meta->status = status;
        update_meta(name, meta.value());
    }
}

void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false;
    });
}

bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name, true);
    }
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
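
// forward an incoming HTTP request to the child server instance hosting the
// given model, returning a streaming proxy response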
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        base_params.hostname,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
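
// called from a child process: notify the router that this instance is ready,
// then return a thread that watches stdin for the exit command (or EOF, which
// means the router itself is gone)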
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // TODO @ngxson : use HTTP client from libcommon
    httplib::Client cli(base_params.hostname, router_port);
    cli.set_connection_timeout(0, 200000); // 200 milliseconds
    httplib::Request req;
    req.method = "POST";
    req.path = "/models/status";
    req.set_header("Content-Type", "application/json");
    if (!base_params.api_keys.empty()) {
        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
    }
    json body;
    body["model"] = name;
    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
    req.body = body.dump();
    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
    auto result = cli.send(std::move(req));
    if (result.error() != httplib::Error::Success) {
        auto err_str = httplib::to_string(result.error());
        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
        exit(1); // force exit
    }
    // setup a thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, meaning the router server unexpectedly exited or was killed
                eof = true;
                break;
            }
            if (line.find(CMD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}

//
// server_models_routes
//

static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
    res->status = 200;
    res->data = safe_json_to_str(response_data);
}

static void res_error(std::unique_ptr<server_http_res> & res, const json & error_data) {
    res->status = json_value(error_data, "code", 500);
    res->data = safe_json_to_str({{ "error", error_data }});
}
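
// check that the requested model exists and is loaded (optionally auto-loading
// it); on failure, fill `res` with an error response and return false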
static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
    if (name.empty()) {
        res_error(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    auto meta = models.get_meta(name);
    if (!meta.has_value()) {
        res_error(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    if (models_autoload) {
        models.ensure_model_loaded(name);
    } else {
        if (meta->status != SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return false;
        }
    }
    return true;
}

static bool is_autoload(const common_params & params, const server_http_req & req) {
    std::string autoload = req.get_param("autoload");
    if (autoload.empty()) {
        return params.models_autoload;
    }
    return autoload == "true" || autoload == "1";
}

void server_models_routes::init_routes() {
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
            });
            return res;
        }
        return proxy_get(req);
    };

    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };

    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST requests only
    };

    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"in_cache", meta.in_cache},
                {"path", meta.path},
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };

    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_error(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name, false);
        res_ok(res, {{"success", true}});
        return res;
    };

    // used by child processes to notify the router about a status change
    // TODO @ngxson : maybe implement authentication for this endpoint in the future
    this->post_router_models_status = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        std::string value = json_value(body, "value", std::string());
        models.update_status(model, server_model_status_from_string(value));
        res_ok(res, {{"success", true}});
        return res;
    };

    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_error(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}

//
// server_http_proxy
//

// simple implementation of a pipe
// used for streaming data between threads
template<typename T>
struct pipe_t {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<T> queue;
    std::atomic<bool> writer_closed{false};
    std::atomic<bool> reader_closed{false};
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
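    // blocking read; returns false on clean EOF (writer closed) or when
    // should_stop() reports that the reader is gone; polls should_stop() every 500 ms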
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
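
// the constructor sends the request on a detached thread; response headers and
// body chunks flow back through a shared pipe_t, so the caller can stream the
// body incrementally via `next` without waiting for the whole response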
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture the `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length)});
    };
    // prepare the request to the destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    this->thread.detach();
    // wait for the first chunk (headers)
    msg_t header;
    if (pipe->read(header, should_stop)) {
        SRV_DBG("%s", "received response headers\n");
        this->status = header.status;
        this->headers = header.headers;
    } else {
        SRV_DBG("%s", "no response headers received (request cancelled?)\n");
    }
}