#include "server-common.h"
#include "server-models.h"
#include "download.h"

#include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
#include <sheredom/subprocess.h>

#include <functional>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <cstring>
#include <atomic>
#include <chrono>
#include <queue>

#ifdef _WIN32
#include <winsock2.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif

#if defined(__APPLE__) && defined(__MACH__)
// macOS: use _NSGetExecutablePath to get the executable path
#include <mach-o/dyld.h>
#include <limits.h>
#endif

#define CMD_EXIT "exit"
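
// get the absolute path of the currently running server executable,
// using the platform-specific API (Win32, macOS or /proc/self/exe on Linux)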
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    return std::filesystem::path(std::string(path, count));
#endif
}
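
// a model found on disk via --models-dir; path_mmproj may be empty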
struct local_model {
    std::string name;
    std::string path;
    std::string path_mmproj;
};
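
// list models in the given directory:
// - each subdirectory is treated as one model: the main .gguf file (or the
//   first shard of a sharded model) plus an optional mmproj .gguf
// - each top-level .gguf file is treated as a single-file model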
static std::vector<local_model> list_local_models(const std::string & dir) {
    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
    }
    std::vector<local_model> models;
    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
        auto files = fs_list(subdir_path, false);
        common_file_info model_file;
        common_file_info first_shard_file;
        common_file_info mmproj_file;
        for (const auto & file : files) {
            if (string_ends_with(file.name, ".gguf")) {
                if (file.name.find("mmproj") != std::string::npos) {
                    mmproj_file = file;
                } else if (file.name.find("-00001-of-") != std::string::npos) {
                    first_shard_file = file;
                } else {
                    model_file = file;
                }
            }
        }
        // prefer the first shard of a sharded model, otherwise use the single-file model
        local_model model{
            /* name */ name,
            /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
            /* path_mmproj */ mmproj_file.path // can be empty
        };
        if (!model.path.empty()) {
            models.push_back(model);
        }
    };
    auto files = fs_list(dir, true);
    for (const auto & file : files) {
        if (file.is_dir) {
            scan_subdir(file.path, file.name);
        } else if (string_ends_with(file.name, ".gguf")) {
            // single file model
            std::string name = file.name;
            string_replace_all(name, ".gguf", "");
            local_model model{
                /* name */ name,
                /* path */ file.path,
                /* path_mmproj */ ""
            };
            models.push_back(model);
        }
    }
    return models;
}

//
// server_models
//

server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp) : base_params(params) {
    for (int i = 0; i < argc; i++) {
        base_args.push_back(std::string(argv[i]));
    }
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    GGML_ASSERT(!base_args.empty());
    // set binary path
    try {
        base_args[0] = get_server_exec_path().string();
    } catch (const std::exception & e) {
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
    }
    // TODO: allow refreshing cached model list
    // add cached models
    auto cached_models = common_list_cached_models();
    for (const auto & model : cached_models) {
        server_model_meta meta{
            /* name */ model.to_string(),
            /* path */ model.manifest_path,
            /* path_mmproj */ "", // auto-detected when loading
            /* in_cache */ true,
            /* port */ 0,
            /* status */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */ 0,
            /* args */ std::vector<std::string>(),
            /* exit_code */ 0
        };
        mapping[meta.name] = instance_t{
            /* subproc */ std::make_shared<subprocess_s>(),
            /* th */ std::thread(),
            /* meta */ meta
        };
    }
    // add local models specified via --models-dir
    if (!params.models_dir.empty()) {
        auto local_models = list_local_models(params.models_dir);
        for (const auto & model : local_models) {
            if (mapping.find(model.name) != mapping.end()) {
                // already exists in cached models, skip
                continue;
            }
            server_model_meta meta{
                /* name */ model.name,
                /* path */ model.path,
                /* path_mmproj */ model.path_mmproj,
                /* in_cache */ false,
                /* port */ 0,
                /* status */ SERVER_MODEL_STATUS_UNLOADED,
                /* last_used */ 0,
                /* args */ std::vector<std::string>(),
                /* exit_code */ 0
            };
            mapping[meta.name] = instance_t{
                /* subproc */ std::make_shared<subprocess_s>(),
                /* th */ std::thread(),
                /* meta */ meta
            };
        }
    }
}

void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        it->second.meta = meta;
    }
    cv.notify_all(); // notify wait_until_loaded
}

bool server_models::has_model(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    return mapping.find(name) != mapping.end();
}

std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        return it->second.meta;
    }
    return std::nullopt;
}
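
// ask the OS for a free TCP port by binding port 0 and reading back the
// assigned port; the socket is closed before the child process re-binds the
// port, so a collision with another process is possible, but unlikely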
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0);
    if (bind(sock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    if (getsockname(sock, (struct sockaddr *) &serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    int port = ntohs(serv_addr.sin_port);
    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif
    return port;
}

// helper to convert vector<string> to char **
// pointers are only valid as long as the original vector is valid
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
    std::vector<char *> result;
    result.reserve(vec.size() + 1);
    for (const auto & s : vec) {
        result.push_back(const_cast<char *>(s.c_str()));
    }
    result.push_back(nullptr);
    return result;
}

std::vector<server_model_meta> server_models::get_all_meta() {
    std::lock_guard<std::mutex> lk(mutex);
    std::vector<server_model_meta> result;
    result.reserve(mapping.size());
    for (const auto & [name, inst] : mapping) {
        result.push_back(inst.meta);
    }
    return result;
}

void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one of the servers if we passed the models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms();
    size_t count_active = 0;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    if (!lru_model_name.empty() && count_active >= (size_t) base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
    }
}
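
// replace the value following an existing CLI flag, or append the flag and
// value if the flag is not present yet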
static void add_or_replace_arg(std::vector<std::string> & args, const std::string & key, const std::string & value) {
    for (size_t i = 0; i < args.size(); i++) {
        if (args[i] == key && i + 1 < args.size()) {
            args[i + 1] = value;
            return;
        }
    }
    // not found, append
    args.push_back(key);
    args.push_back(value);
}
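
// spawn a child server process for the given model on a freshly allocated port;
// the child reuses the router's args/env, with -m/-hf, --port and --alias
// adjusted per model; a managing thread forwards its logs and tracks its exit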
void server_models::load(const std::string & name, bool auto_load) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model %s is already loading or loaded, skipping\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        std::vector<std::string> child_args;
        if (auto_load && !meta.args.empty()) {
            child_args = meta.args; // copy previous args
        } else {
            child_args = base_args; // copy
            if (inst.meta.in_cache) {
                add_or_replace_arg(child_args, "-hf", inst.meta.name);
            } else {
                add_or_replace_arg(child_args, "-m", inst.meta.path);
                if (!inst.meta.path_mmproj.empty()) {
                    add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj);
                }
            }
        }
        // set model args
        add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port));
        add_or_replace_arg(child_args, "--alias", inst.meta.name);
        std::vector<std::string> child_env = base_env; // copy
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update exit code and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all();
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if it exists
    {
        auto & old_instance = mapping[name];
        // the old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}

static void interrupt_subprocess(FILE * stdin_file) {
    // because subprocess.h does not provide a way to send SIGINT,
    // we will send a command to the child process to exit gracefully
    if (stdin_file) {
        fprintf(stdin_file, "%s\n", CMD_EXIT);
        fflush(stdin_file);
    }
}

void server_models::unload(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        if (it->second.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            interrupt_subprocess(it->second.stdin_file);
            // status change will be handled by the managing thread
        } else {
            SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
        }
    }
}
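
// ask all active instances to exit gracefully, then join their managing threads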
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // move the thread to the join list, so we don't join while holding the lock (deadlock)
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}

void server_models::update_status(const std::string & name, server_model_status status) {
    // for now, we only allow updating to LOADED status
    if (status != SERVER_MODEL_STATUS_LOADED) {
        throw std::runtime_error("invalid status value");
    }
    auto meta = get_meta(name);
    if (meta.has_value()) {
        meta->status = status;
        update_meta(name, meta.value());
    }
}

void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false;
    });
}
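
// make sure the model is loaded, triggering a load if necessary and blocking
// until it is ready; returns false if it was already loaded, true if a load
// was performed; throws if the model is unknown or failed to load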
bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name, true);
    }
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
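
// forward an incoming HTTP request to the child instance serving the given
// model and stream the response back to the caller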
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        base_params.hostname,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
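
// called from the child process: notify the router that this instance is
// ready, then start a thread that watches stdin for the exit command (EOF on
// stdin means the router died, in which case the child exits as well)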
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // TODO @ngxson : use HTTP client from libcommon
    httplib::Client cli(base_params.hostname, router_port);
    cli.set_connection_timeout(0, 200000); // 200 milliseconds
    httplib::Request req;
    req.method = "POST";
    req.path = "/models/status";
    req.set_header("Content-Type", "application/json");
    if (!base_params.api_keys.empty()) {
        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
    }
    json body;
    body["model"] = name;
    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
    req.body = body.dump();
    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
    auto result = cli.send(std::move(req));
    if (result.error() != httplib::Error::Success) {
        auto err_str = httplib::to_string(result.error());
        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
        exit(1); // force exit
    }
    // set up a thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, which means the router server exited unexpectedly or was killed
                eof = true;
                break;
            }
            if (line.find(CMD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}

//
// server_models_routes
//

static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
    res->status = 200;
    res->data = safe_json_to_str(response_data);
}

static void res_error(std::unique_ptr<server_http_res> & res, const json & error_data) {
    res->status = json_value(error_data, "code", 500);
    res->data = safe_json_to_str({{ "error", error_data }});
}

static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
    if (name.empty()) {
        res_error(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    auto meta = models.get_meta(name);
    if (!meta.has_value()) {
        res_error(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    if (models_autoload) {
        models.ensure_model_loaded(name);
    } else {
        if (meta->status != SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return false;
        }
    }
    return true;
}

static bool is_autoload(const common_params & params, const server_http_req & req) {
    std::string autoload = req.get_param("autoload");
    if (autoload.empty()) {
        return params.models_autoload;
    } else {
        return autoload == "true" || autoload == "1";
    }
}

void server_models_routes::init_routes() {
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
            });
            return res;
        }
        return proxy_get(req);
    };
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"in_cache", meta.in_cache},
                {"path", meta.path},
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_error(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name, false);
        res_ok(res, {{"success", true}});
        return res;
    };
    // used by child process to notify the router about status change
    // TODO @ngxson : maybe implement authentication for this endpoint in the future
    this->post_router_models_status = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        std::string value = json_value(body, "value", std::string());
        models.update_status(model, server_model_status_from_string(value));
        res_ok(res, {{"success", true}});
        return res;
    };
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_error(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
            res_error(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}

//
// server_http_proxy
//

// simple implementation of a pipe
// used for streaming data between threads
template<typename T>
struct pipe_t {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<T> queue;
    std::atomic<bool> writer_closed{false};
    std::atomic<bool> reader_closed{false};
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
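
// proxy a request to the destination server and stream the response back:
// a detached client thread pushes the response into a pipe_t, sending the
// status code and headers as the first message, followed by body chunks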
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length)});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    this->thread.detach();
    // wait for the first chunk (headers)
    msg_t header;
    if (pipe->read(header, should_stop)) {
        SRV_DBG("%s", "received response headers\n");
        this->status = header.status;
        this->headers = header.headers;
    } else {
        SRV_DBG("%s", "no response headers received (request cancelled?)\n");
    }
}