// server-models.cpp
  1. #include "server-common.h"
  2. #include "server-models.h"
  3. #include "download.h"
  4. #include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
  5. #include <sheredom/subprocess.h>
  6. #include <functional>
  7. #include <algorithm>
  8. #include <thread>
  9. #include <mutex>
  10. #include <condition_variable>
  11. #include <cstring>
  12. #include <atomic>
  13. #include <chrono>
  14. #include <queue>
  15. #ifdef _WIN32
  16. #include <winsock2.h>
  17. #else
  18. #include <sys/socket.h>
  19. #include <netinet/in.h>
  20. #include <arpa/inet.h>
  21. #include <unistd.h>
  22. #endif
  23. #if defined(__APPLE__) && defined(__MACH__)
  24. // macOS: use _NSGetExecutablePath to get the executable path
  25. #include <mach-o/dyld.h>
  26. #include <limits.h>
  27. #endif
  28. #define CMD_EXIT "exit"
// Resolve the absolute path of the currently running server executable.
// Used so that spawned child model instances run the exact same binary as the router.
// Throws std::runtime_error when the platform API fails.
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        // len == _countof(buf) means the path was truncated
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            // canonical() throws when the path cannot be resolved; fall back to the raw path
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        // (_NSGetExecutablePath wrote the needed size into `size`)
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    // Linux and other POSIX systems exposing /proc/self/exe
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    // readlink() does not null-terminate; build the string from the byte count.
    // NOTE(review): if the real path is longer than FILENAME_MAX the result is
    // silently truncated -- worth confirming this is acceptable.
    return std::filesystem::path(std::string(path, count));
#endif
}
// A model discovered on local disk (via the --models-dir scan).
// NOTE: aggregate-initialized positionally below -- do not reorder members.
struct local_model {
    std::string name;        // alias name (directory name, or file name without ".gguf")
    std::string path;        // path to the .gguf file (or the first shard of a split model)
    std::string path_mmproj; // path to the multimodal projector .gguf; empty if none
};
  73. static std::vector<local_model> list_local_models(const std::string & dir) {
  74. if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
  75. throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
  76. }
  77. std::vector<local_model> models;
  78. auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
  79. auto files = fs_list(subdir_path, false);
  80. common_file_info model_file;
  81. common_file_info first_shard_file;
  82. common_file_info mmproj_file;
  83. for (const auto & file : files) {
  84. if (string_ends_with(file.name, ".gguf")) {
  85. if (file.name.find("mmproj") != std::string::npos) {
  86. mmproj_file = file;
  87. } else if (file.name.find("-00001-of-") != std::string::npos) {
  88. first_shard_file = file;
  89. } else {
  90. model_file = file;
  91. }
  92. }
  93. }
  94. // single file model
  95. local_model model{
  96. /* name */ name,
  97. /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
  98. /* path_mmproj */ mmproj_file.path // can be empty
  99. };
  100. if (!model.path.empty()) {
  101. models.push_back(model);
  102. }
  103. };
  104. auto files = fs_list(dir, true);
  105. for (const auto & file : files) {
  106. if (file.is_dir) {
  107. scan_subdir(file.path, file.name);
  108. } else if (string_ends_with(file.name, ".gguf")) {
  109. // single file model
  110. std::string name = file.name;
  111. string_replace_all(name, ".gguf", "");
  112. local_model model{
  113. /* name */ name,
  114. /* path */ file.path,
  115. /* path_mmproj */ ""
  116. };
  117. models.push_back(model);
  118. }
  119. }
  120. return models;
  121. }
  122. //
  123. // server_models
  124. //
  125. server_models::server_models(
  126. const common_params & params,
  127. int argc,
  128. char ** argv,
  129. char ** envp) : base_params(params) {
  130. for (int i = 0; i < argc; i++) {
  131. base_args.push_back(std::string(argv[i]));
  132. }
  133. for (char ** env = envp; *env != nullptr; env++) {
  134. base_env.push_back(std::string(*env));
  135. }
  136. GGML_ASSERT(!base_args.empty());
  137. // set binary path
  138. try {
  139. base_args[0] = get_server_exec_path().string();
  140. } catch (const std::exception & e) {
  141. LOG_WRN("failed to get server executable path: %s\n", e.what());
  142. LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
  143. }
  144. // TODO: allow refreshing cached model list
  145. // add cached models
  146. auto cached_models = common_list_cached_models();
  147. for (const auto & model : cached_models) {
  148. server_model_meta meta{
  149. /* name */ model.to_string(),
  150. /* path */ model.manifest_path,
  151. /* path_mmproj */ "", // auto-detected when loading
  152. /* in_cache */ true,
  153. /* port */ 0,
  154. /* status */ SERVER_MODEL_STATUS_UNLOADED,
  155. /* last_used */ 0,
  156. /* args */ std::vector<std::string>(),
  157. /* exit_code */ 0
  158. };
  159. mapping[meta.name] = instance_t{
  160. /* subproc */ std::make_shared<subprocess_s>(),
  161. /* th */ std::thread(),
  162. /* meta */ meta
  163. };
  164. }
  165. // add local models specificed via --models-dir
  166. if (!params.models_dir.empty()) {
  167. auto local_models = list_local_models(params.models_dir);
  168. for (const auto & model : local_models) {
  169. if (mapping.find(model.name) != mapping.end()) {
  170. // already exists in cached models, skip
  171. continue;
  172. }
  173. server_model_meta meta{
  174. /* name */ model.name,
  175. /* path */ model.path,
  176. /* path_mmproj */ model.path_mmproj,
  177. /* in_cache */ false,
  178. /* port */ 0,
  179. /* status */ SERVER_MODEL_STATUS_UNLOADED,
  180. /* last_used */ 0,
  181. /* args */ std::vector<std::string>(),
  182. /* exit_code */ 0
  183. };
  184. mapping[meta.name] = instance_t{
  185. /* subproc */ std::make_shared<subprocess_s>(),
  186. /* th */ std::thread(),
  187. /* meta */ meta
  188. };
  189. }
  190. }
  191. }
  192. void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
  193. std::lock_guard<std::mutex> lk(mutex);
  194. auto it = mapping.find(name);
  195. if (it != mapping.end()) {
  196. it->second.meta = meta;
  197. }
  198. cv.notify_all(); // notify wait_until_loaded
  199. }
  200. bool server_models::has_model(const std::string & name) {
  201. std::lock_guard<std::mutex> lk(mutex);
  202. return mapping.find(name) != mapping.end();
  203. }
  204. std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
  205. std::lock_guard<std::mutex> lk(mutex);
  206. auto it = mapping.find(name);
  207. if (it != mapping.end()) {
  208. return it->second.meta;
  209. }
  210. return std::nullopt;
  211. }
  212. static int get_free_port() {
  213. #ifdef _WIN32
  214. WSADATA wsaData;
  215. if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
  216. return -1;
  217. }
  218. typedef SOCKET native_socket_t;
  219. #define INVALID_SOCKET_VAL INVALID_SOCKET
  220. #define CLOSE_SOCKET(s) closesocket(s)
  221. #else
  222. typedef int native_socket_t;
  223. #define INVALID_SOCKET_VAL -1
  224. #define CLOSE_SOCKET(s) close(s)
  225. #endif
  226. native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
  227. if (sock == INVALID_SOCKET_VAL) {
  228. #ifdef _WIN32
  229. WSACleanup();
  230. #endif
  231. return -1;
  232. }
  233. struct sockaddr_in serv_addr;
  234. std::memset(&serv_addr, 0, sizeof(serv_addr));
  235. serv_addr.sin_family = AF_INET;
  236. serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  237. serv_addr.sin_port = htons(0);
  238. if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
  239. CLOSE_SOCKET(sock);
  240. #ifdef _WIN32
  241. WSACleanup();
  242. #endif
  243. return -1;
  244. }
  245. #ifdef _WIN32
  246. int namelen = sizeof(serv_addr);
  247. #else
  248. socklen_t namelen = sizeof(serv_addr);
  249. #endif
  250. if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
  251. CLOSE_SOCKET(sock);
  252. #ifdef _WIN32
  253. WSACleanup();
  254. #endif
  255. return -1;
  256. }
  257. int port = ntohs(serv_addr.sin_port);
  258. CLOSE_SOCKET(sock);
  259. #ifdef _WIN32
  260. WSACleanup();
  261. #endif
  262. return port;
  263. }
  264. // helper to convert vector<string> to char **
  265. // pointers are only valid as long as the original vector is valid
  266. static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
  267. std::vector<char *> result;
  268. result.reserve(vec.size() + 1);
  269. for (const auto & s : vec) {
  270. result.push_back(const_cast<char*>(s.c_str()));
  271. }
  272. result.push_back(nullptr);
  273. return result;
  274. }
  275. std::vector<server_model_meta> server_models::get_all_meta() {
  276. std::lock_guard<std::mutex> lk(mutex);
  277. std::vector<server_model_meta> result;
  278. result.reserve(mapping.size());
  279. for (const auto & [name, inst] : mapping) {
  280. result.push_back(inst.meta);
  281. }
  282. return result;
  283. }
// Enforce the --models-max limit: when at least models_max instances are
// active, ask the least-recently-used one to shut down to make room for a
// new instance. Called from load() before spawning.
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one of the servers if we passed the models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms(); // "now": any active model was used earlier
    size_t count_active = 0;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    // unload() re-acquires the mutex, so it must be called outside the scope above
    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
    }
}
  309. static void add_or_replace_arg(std::vector<std::string> & args, const std::string & key, const std::string & value) {
  310. for (size_t i = 0; i < args.size(); i++) {
  311. if (args[i] == key && i + 1 < args.size()) {
  312. args[i + 1] = value;
  313. return;
  314. }
  315. }
  316. // not found, append
  317. args.push_back(key);
  318. args.push_back(value);
  319. }
// Spawn a llama-server child process serving model `name` on a freshly
// allocated port. When `auto_load` is true and the model was run before, the
// args saved from the previous run are reused. Throws when the model is
// unknown, no port can be obtained, or the subprocess cannot be created.
void server_models::load(const std::string & name, bool auto_load) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // make room if we are at the --models-max limit (locks the mutex itself)
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        std::vector<std::string> child_args;
        if (auto_load && !meta.args.empty()) {
            child_args = meta.args; // copy previous args
        } else {
            child_args = base_args; // copy
            if (inst.meta.in_cache) {
                // cached models are addressed by their HF name
                add_or_replace_arg(child_args, "-hf", inst.meta.name);
            } else {
                add_or_replace_arg(child_args, "-m", inst.meta.path);
                if (!inst.meta.path_mmproj.empty()) {
                    add_or_replace_arg(child_args, "--mmproj", inst.meta.path_mmproj);
                }
            }
        }
        // set model args
        add_or_replace_arg(child_args, "--port", std::to_string(inst.meta.port));
        add_or_replace_arg(child_args, "--alias", inst.meta.name);
        std::vector<std::string> child_env = base_env; // copy
        // tell the child where the router listens so it can report back its status
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        // argv/envp views stay valid because child_args/child_env outlive the call
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            // blocks until the child closes its output, i.e. until it exits
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update PID and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all();
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
  421. static void interrupt_subprocess(FILE * stdin_file) {
  422. // because subprocess.h does not provide a way to send SIGINT,
  423. // we will send a command to the child process to exit gracefully
  424. if (stdin_file) {
  425. fprintf(stdin_file, "%s\n", CMD_EXIT);
  426. fflush(stdin_file);
  427. }
  428. }
  429. void server_models::unload(const std::string & name) {
  430. std::lock_guard<std::mutex> lk(mutex);
  431. auto it = mapping.find(name);
  432. if (it != mapping.end()) {
  433. if (it->second.meta.is_active()) {
  434. SRV_INF("unloading model instance name=%s\n", name.c_str());
  435. interrupt_subprocess(it->second.stdin_file);
  436. // status change will be handled by the managing thread
  437. } else {
  438. SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
  439. }
  440. }
  441. }
// Request a graceful shutdown of every active instance, then wait for all of
// their managing threads to finish.
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // moving the thread to join list to avoid deadlock
            // (the managing thread locks `mutex` when its child exits, so we must
            // not join while holding the lock)
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
  462. void server_models::update_status(const std::string & name, server_model_status status) {
  463. // for now, we only allow updating to LOADED status
  464. if (status != SERVER_MODEL_STATUS_LOADED) {
  465. throw std::runtime_error("invalid status value");
  466. }
  467. auto meta = get_meta(name);
  468. if (meta.has_value()) {
  469. meta->status = status;
  470. update_meta(name, meta.value());
  471. }
  472. }
// Block until model `name` leaves the LOADING state -- either it became
// LOADED, or the child exited and the managing thread reset it to UNLOADED.
// NOTE(review): if `name` is not present in `mapping` the predicate never
// becomes true and this waits forever -- callers must validate the name first.
void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false;
    });
}
  483. bool server_models::ensure_model_loaded(const std::string & name) {
  484. auto meta = get_meta(name);
  485. if (!meta.has_value()) {
  486. throw std::runtime_error("model name=" + name + " is not found");
  487. }
  488. if (meta->status == SERVER_MODEL_STATUS_LOADED) {
  489. return false; // already loaded
  490. }
  491. if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
  492. SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
  493. load(name, true);
  494. }
  495. SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
  496. wait_until_loaded(name);
  497. // check final status
  498. meta = get_meta(name);
  499. if (!meta.has_value() || meta->is_failed()) {
  500. throw std::runtime_error("model name=" + name + " failed to load");
  501. }
  502. return true;
  503. }
// Forward an HTTP request to the child instance serving `name` and stream the
// response back. Throws std::runtime_error for unknown models and
// std::invalid_argument when the model is not in the LOADED state.
// `update_last_used` refreshes the LRU timestamp (set for POST requests only).
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    // the proxy object connects to the child on its private port and relays
    // the response (including streaming) back to the original caller
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        base_params.hostname,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
// Runs inside a child server process: notify the router that this model is
// ready, then return a thread that watches stdin for the router's exit
// command (or EOF, which means the router itself died).
// Exits the process with code 1 when the router cannot be reached.
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // TODO @ngxson : use HTTP client from libcommon
    httplib::Client cli(base_params.hostname, router_port);
    cli.set_connection_timeout(0, 200000); // 200 milliseconds
    httplib::Request req;
    req.method = "POST";
    req.path = "/models/status";
    req.set_header("Content-Type", "application/json");
    if (!base_params.api_keys.empty()) {
        // reuse the first API key configured on the router
        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
    }
    json body;
    body["model"] = name;
    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
    req.body = body.dump();
    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
    auto result = cli.send(std::move(req));
    if (result.error() != httplib::Error::Success) {
        auto err_str = httplib::to_string(result.error());
        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
        exit(1); // force exit
    }
    // setup thread for monitoring stdin
    // note: shutdown_handler is captured by copy, so the reference parameter's
    // lifetime does not matter once the thread is started
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, that means the router server is unexpectedly exit or killed
                eof = true;
                break;
            }
            if (line.find(CMD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
  574. //
  575. // server_models_routes
  576. //
  577. static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
  578. res->status = 200;
  579. res->data = safe_json_to_str(response_data);
  580. }
  581. static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
  582. res->status = json_value(error_data, "code", 500);
  583. res->data = safe_json_to_str({{ "error", error_data }});
  584. }
  585. static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
  586. if (name.empty()) {
  587. res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
  588. return false;
  589. }
  590. auto meta = models.get_meta(name);
  591. if (!meta.has_value()) {
  592. res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
  593. return false;
  594. }
  595. if (models_autoload) {
  596. models.ensure_model_loaded(name);
  597. } else {
  598. if (meta->status != SERVER_MODEL_STATUS_LOADED) {
  599. res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
  600. return false;
  601. }
  602. }
  603. return true;
  604. }
  605. static bool is_autoload(const common_params & params, const server_http_req & req) {
  606. std::string autoload = req.get_param("autoload");
  607. if (autoload.empty()) {
  608. return params.models_autoload;
  609. } else {
  610. return autoload == "true" || autoload == "1";
  611. }
  612. }
  613. void server_models_routes::init_routes() {
  614. this->get_router_props = [this](const server_http_req & req) {
  615. std::string name = req.get_param("model");
  616. if (name.empty()) {
  617. // main instance
  618. auto res = std::make_unique<server_http_res>();
  619. res_ok(res, {
  620. // TODO: add support for this on web UI
  621. {"role", "router"},
  622. {"max_instances", 4}, // dummy value for testing
  623. // this is a dummy response to make sure webui doesn't break
  624. {"model_alias", "llama-server"},
  625. {"model_path", "none"},
  626. {"default_generation_settings", {
  627. {"params", json{}},
  628. {"n_ctx", 0},
  629. }},
  630. });
  631. return res;
  632. }
  633. return proxy_get(req);
  634. };
  635. this->proxy_get = [this](const server_http_req & req) {
  636. std::string method = "GET";
  637. std::string name = req.get_param("model");
  638. bool autoload = is_autoload(params, req);
  639. auto error_res = std::make_unique<server_http_res>();
  640. if (!router_validate_model(name, models, autoload, error_res)) {
  641. return error_res;
  642. }
  643. return models.proxy_request(req, method, name, false);
  644. };
  645. this->proxy_post = [this](const server_http_req & req) {
  646. std::string method = "POST";
  647. json body = json::parse(req.body);
  648. std::string name = json_value(body, "model", std::string());
  649. bool autoload = is_autoload(params, req);
  650. auto error_res = std::make_unique<server_http_res>();
  651. if (!router_validate_model(name, models, autoload, error_res)) {
  652. return error_res;
  653. }
  654. return models.proxy_request(req, method, name, true); // update last usage for POST request only
  655. };
  656. this->get_router_models = [this](const server_http_req &) {
  657. auto res = std::make_unique<server_http_res>();
  658. json models_json = json::array();
  659. auto all_models = models.get_all_meta();
  660. std::time_t t = std::time(0);
  661. for (const auto & meta : all_models) {
  662. json status {
  663. {"value", server_model_status_to_string(meta.status)},
  664. {"args", meta.args},
  665. };
  666. if (meta.is_failed()) {
  667. status["exit_code"] = meta.exit_code;
  668. status["failed"] = true;
  669. }
  670. models_json.push_back(json {
  671. {"id", meta.name},
  672. {"object", "model"}, // for OAI-compat
  673. {"owned_by", "llamacpp"}, // for OAI-compat
  674. {"created", t}, // for OAI-compat
  675. {"in_cache", meta.in_cache},
  676. {"path", meta.path},
  677. {"status", status},
  678. // TODO: add other fields, may require reading GGUF metadata
  679. });
  680. }
  681. res_ok(res, {
  682. {"data", models_json},
  683. {"object", "list"},
  684. });
  685. return res;
  686. };
  687. this->post_router_models_load = [this](const server_http_req & req) {
  688. auto res = std::make_unique<server_http_res>();
  689. json body = json::parse(req.body);
  690. std::string name = json_value(body, "model", std::string());
  691. auto model = models.get_meta(name);
  692. if (!model.has_value()) {
  693. res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
  694. return res;
  695. }
  696. if (model->status == SERVER_MODEL_STATUS_LOADED) {
  697. res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
  698. return res;
  699. }
  700. models.load(name, false);
  701. res_ok(res, {{"success", true}});
  702. return res;
  703. };
  704. // used by child process to notify the router about status change
  705. // TODO @ngxson : maybe implement authentication for this endpoint in the future
  706. this->post_router_models_status = [this](const server_http_req & req) {
  707. auto res = std::make_unique<server_http_res>();
  708. json body = json::parse(req.body);
  709. std::string model = json_value(body, "model", std::string());
  710. std::string value = json_value(body, "value", std::string());
  711. models.update_status(model, server_model_status_from_string(value));
  712. res_ok(res, {{"success", true}});
  713. return res;
  714. };
  715. this->get_router_models = [this](const server_http_req &) {
  716. auto res = std::make_unique<server_http_res>();
  717. json models_json = json::array();
  718. auto all_models = models.get_all_meta();
  719. std::time_t t = std::time(0);
  720. for (const auto & meta : all_models) {
  721. json status {
  722. {"value", server_model_status_to_string(meta.status)},
  723. {"args", meta.args},
  724. };
  725. if (meta.is_failed()) {
  726. status["exit_code"] = meta.exit_code;
  727. status["failed"] = true;
  728. }
  729. models_json.push_back(json {
  730. {"id", meta.name},
  731. {"object", "model"}, // for OAI-compat
  732. {"owned_by", "llamacpp"}, // for OAI-compat
  733. {"created", t}, // for OAI-compat
  734. {"in_cache", meta.in_cache},
  735. {"path", meta.path},
  736. {"status", status},
  737. // TODO: add other fields, may require reading GGUF metadata
  738. });
  739. }
  740. res_ok(res, {
  741. {"data", models_json},
  742. {"object", "list"},
  743. });
  744. return res;
  745. };
  746. this->post_router_models_unload = [this](const server_http_req & req) {
  747. auto res = std::make_unique<server_http_res>();
  748. json body = json::parse(req.body);
  749. std::string name = json_value(body, "model", std::string());
  750. auto model = models.get_meta(name);
  751. if (!model.has_value()) {
  752. res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
  753. return res;
  754. }
  755. if (model->status != SERVER_MODEL_STATUS_LOADED) {
  756. res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
  757. return res;
  758. }
  759. models.unload(name);
  760. res_ok(res, {{"success", true}});
  761. return res;
  762. };
  763. }
  764. //
  765. // server_http_proxy
  766. //
  767. // simple implementation of a pipe
  768. // used for streaming data between threads
// Single-producer / single-consumer in-memory pipe. The reader polls every
// 500 ms so that a cancelled request (should_stop) is noticed even when no
// new data arrives.
template<typename T>
struct pipe_t {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<T> queue;
    std::atomic<bool> writer_closed{false}; // producer finished (EOF)
    std::atomic<bool> reader_closed{false}; // consumer went away (broken pipe)
    // Signal EOF to the reader; items already queued can still be drained.
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // Signal broken pipe to the writer; subsequent write() calls return false.
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // Pop the next item into `output`. Returns false on clean EOF (writer
    // closed and queue drained) or when should_stop() reports cancellation.
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to writer
                return false; // cancelled / reader no longer alive
            }
            // timed wait so should_stop() is re-checked periodically even if a
            // notification is missed
            cv.wait_for(lk, poll_interval);
        }
    }
    // Push one item. Returns false when the reader is gone (broken pipe).
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
  813. static std::string to_lower_copy(const std::string & value) {
  814. std::string lowered(value.size(), '\0');
  815. std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
  816. return lowered;
  817. }
  818. static bool should_strip_proxy_header(const std::string & header_name) {
  819. // Headers that get duplicated when router forwards child responses
  820. if (header_name == "server" ||
  821. header_name == "transfer-encoding" ||
  822. header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
  823. header_name == "keep-alive") {
  824. return true;
  825. }
  826. // Router injects CORS, child also sends them: duplicate
  827. if (header_name.rfind("access-control-", 0) == 0) {
  828. return true;
  829. }
  830. return false;
  831. }
// proxy one HTTP request to a child server, streaming the response back to the
// caller through a pipe_t<msg_t>; the outbound request runs on a detached thread
// whose lifetime is tied to the shared_ptr captures (`cli`, `pipe`), never `this`
server_http_proxy::server_http_proxy(
    const std::string & method,
    const std::string & host,
    int port,
    const std::string & path,
    const std::map<std::string, std::string> & headers,
    const std::string & body,
    const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    // closing both ends unblocks any thread still reading from / writing to the pipe
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        // the first message (headers) carries no data; skip empty payloads
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        // first pipe message carries only status + filtered headers (no body)
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            // drop headers the router sets itself, to avoid duplicates
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            // content-type is stored separately from the generic header map
            if (lowered == "content-type") {
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        // NOTE(review): `req` is a const by-value capture, so std::move here
        // actually copies — a `mutable` lambda would enable a true move
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // surface the client-side failure as a synthetic 500 response
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detached on purpose: the thread owns shared_ptr copies of `cli`/`pipe`,
    // so it stays valid even if this proxy object is destroyed first
    this->thread.detach();
    // wait for the first chunk (headers) — blocks until the response handler
    // runs, the request fails, or `should_stop` cancels the wait
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}