// server-models.cpp — model router: discovers model presets, spawns one child
// llama-server process per loaded model, and proxies HTTP requests to it.
  1. #include "server-common.h"
  2. #include "server-models.h"
  3. #include "preset.h"
  4. #include "download.h"
  5. #include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
  6. #include <sheredom/subprocess.h>
  7. #include <functional>
  8. #include <algorithm>
  9. #include <thread>
  10. #include <mutex>
  11. #include <condition_variable>
  12. #include <cstring>
  13. #include <atomic>
  14. #include <chrono>
  15. #include <queue>
  16. #include <filesystem>
  17. #include <cstring>
  18. #ifdef _WIN32
  19. #include <winsock2.h>
  20. #else
  21. #include <sys/socket.h>
  22. #include <netinet/in.h>
  23. #include <arpa/inet.h>
  24. #include <unistd.h>
  25. #endif
  26. #if defined(__APPLE__) && defined(__MACH__)
  27. // macOS: use _NSGetExecutablePath to get the executable path
  28. #include <mach-o/dyld.h>
  29. #include <limits.h>
  30. #endif
  31. #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
  32. #define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
  33. // address for child process, this is needed because router may run on 0.0.0.0
  34. // ref: https://github.com/ggml-org/llama.cpp/issues/17862
  35. #define CHILD_ADDR "127.0.0.1"
  36. static std::filesystem::path get_server_exec_path() {
  37. #if defined(_WIN32)
  38. wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
  39. DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
  40. if (len == 0 || len >= _countof(buf)) {
  41. throw std::runtime_error("GetModuleFileNameW failed or path too long");
  42. }
  43. return std::filesystem::path(buf);
  44. #elif defined(__APPLE__) && defined(__MACH__)
  45. char small_path[PATH_MAX];
  46. uint32_t size = sizeof(small_path);
  47. if (_NSGetExecutablePath(small_path, &size) == 0) {
  48. // resolve any symlinks to get absolute path
  49. try {
  50. return std::filesystem::canonical(std::filesystem::path(small_path));
  51. } catch (...) {
  52. return std::filesystem::path(small_path);
  53. }
  54. } else {
  55. // buffer was too small, allocate required size and call again
  56. std::vector<char> buf(size);
  57. if (_NSGetExecutablePath(buf.data(), &size) == 0) {
  58. try {
  59. return std::filesystem::canonical(std::filesystem::path(buf.data()));
  60. } catch (...) {
  61. return std::filesystem::path(buf.data());
  62. }
  63. }
  64. throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
  65. }
  66. #else
  67. char path[FILENAME_MAX];
  68. ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
  69. if (count <= 0) {
  70. throw std::runtime_error("failed to resolve /proc/self/exe");
  71. }
  72. return std::filesystem::path(std::string(path, count));
  73. #endif
  74. }
  75. static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
  76. preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
  77. preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
  78. preset.unset_option("LLAMA_API_KEY");
  79. preset.unset_option("LLAMA_ARG_MODELS_DIR");
  80. preset.unset_option("LLAMA_ARG_MODELS_MAX");
  81. preset.unset_option("LLAMA_ARG_MODELS_PRESET");
  82. preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
  83. if (unset_model_args) {
  84. preset.unset_option("LLAMA_ARG_MODEL");
  85. preset.unset_option("LLAMA_ARG_MMPROJ");
  86. preset.unset_option("LLAMA_ARG_HF_REPO");
  87. }
  88. }
// Re-render the command line used to spawn this model's child server.
// The child is forced onto the loopback address (CHILD_ADDR) at this meta's
// port and advertises `name` as its alias; router-reserved options are
// stripped first so they are never forwarded.
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    unset_reserved_args(preset, false); // keep model args: the child needs them to load the model
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate preset before rendering ?
    // render args
    args = preset.to_args(bin_path);
}
  99. //
  100. // server_models
  101. //
// Build the model router: snapshot the process environment (re-used when
// spawning children), load the base preset from the router's own CLI args,
// resolve the server binary path, then discover all available models.
server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp)
        : ctx_preset(LLAMA_EXAMPLE_SERVER),
          base_params(params),
          base_preset(ctx_preset.load_from_args(argc, argv)) {
    // copy the environment so children inherit it (plus router-specific vars)
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    // clean up base preset
    unset_reserved_args(base_preset, true);
    // set binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        // argv[0] may be relative, but it is the best fallback available
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}
// Register one model in the routing table (name must be unique) and render
// its child-process args once so they can be inspected before first load.
// Throws std::runtime_error on a duplicate name.
// NOTE(review): mapping is accessed without the mutex here — presumably this
// is only called from the constructor path, before any threads exist; confirm.
void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    meta.update_args(ctx_preset, bin_path); // render args
    std::string name = meta.name; // copy the key before meta is moved from
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<subprocess_s>(),
        /* th */ std::thread(),
        /* meta */ std::move(meta)
    };
}
  137. // TODO: allow refreshing cached model list
  138. void server_models::load_models() {
  139. // loading models from 3 sources:
  140. // 1. cached models
  141. common_presets cached_models = ctx_preset.load_from_cache();
  142. SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
  143. // 2. local models from --models-dir
  144. common_presets local_models;
  145. if (!base_params.models_dir.empty()) {
  146. local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
  147. SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
  148. }
  149. // 3. custom-path models from presets
  150. common_preset global = {};
  151. common_presets custom_presets = {};
  152. if (!base_params.models_preset.empty()) {
  153. custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
  154. SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
  155. }
  156. // cascade, apply global preset first
  157. cached_models = ctx_preset.cascade(global, cached_models);
  158. local_models = ctx_preset.cascade(global, local_models);
  159. custom_presets = ctx_preset.cascade(global, custom_presets);
  160. // note: if a model exists in both cached and local, local takes precedence
  161. common_presets final_presets;
  162. for (const auto & [name, preset] : cached_models) {
  163. final_presets[name] = preset;
  164. }
  165. for (const auto & [name, preset] : local_models) {
  166. final_presets[name] = preset;
  167. }
  168. // process custom presets from INI
  169. for (const auto & [name, custom] : custom_presets) {
  170. if (final_presets.find(name) != final_presets.end()) {
  171. // apply custom config if exists
  172. common_preset & target = final_presets[name];
  173. target.merge(custom);
  174. } else {
  175. // otherwise add directly
  176. final_presets[name] = custom;
  177. }
  178. }
  179. // server base preset from CLI args take highest precedence
  180. for (auto & [name, preset] : final_presets) {
  181. preset.merge(base_preset);
  182. }
  183. // convert presets to server_model_meta and add to mapping
  184. for (const auto & preset : final_presets) {
  185. server_model_meta meta{
  186. /* preset */ preset.second,
  187. /* name */ preset.first,
  188. /* port */ 0,
  189. /* status */ SERVER_MODEL_STATUS_UNLOADED,
  190. /* last_used */ 0,
  191. /* args */ std::vector<std::string>(),
  192. /* exit_code */ 0
  193. };
  194. add_model(std::move(meta));
  195. }
  196. // log available models
  197. {
  198. std::unordered_set<std::string> custom_names;
  199. for (const auto & [name, preset] : custom_presets) {
  200. custom_names.insert(name);
  201. }
  202. SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
  203. for (const auto & [name, inst] : mapping) {
  204. bool has_custom = custom_names.find(name) != custom_names.end();
  205. SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
  206. }
  207. }
  208. // load any autoload models
  209. std::vector<std::string> models_to_load;
  210. for (const auto & [name, inst] : mapping) {
  211. std::string val;
  212. if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
  213. models_to_load.push_back(name);
  214. }
  215. }
  216. if ((int)models_to_load.size() > base_params.models_max) {
  217. throw std::runtime_error(string_format(
  218. "number of models to load on startup (%zu) exceeds models_max (%d)",
  219. models_to_load.size(),
  220. base_params.models_max
  221. ));
  222. }
  223. for (const auto & name : models_to_load) {
  224. SRV_INF("(startup) loading model %s\n", name.c_str());
  225. load(name);
  226. }
  227. }
  228. void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
  229. std::lock_guard<std::mutex> lk(mutex);
  230. auto it = mapping.find(name);
  231. if (it != mapping.end()) {
  232. it->second.meta = meta;
  233. }
  234. cv.notify_all(); // notify wait_until_loaded
  235. }
  236. bool server_models::has_model(const std::string & name) {
  237. std::lock_guard<std::mutex> lk(mutex);
  238. return mapping.find(name) != mapping.end();
  239. }
  240. std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
  241. std::lock_guard<std::mutex> lk(mutex);
  242. auto it = mapping.find(name);
  243. if (it != mapping.end()) {
  244. return it->second.meta;
  245. }
  246. return std::nullopt;
  247. }
// Ask the OS for a currently-free TCP port by binding port 0 and reading back
// the assigned port with getsockname(). Returns -1 on any socket error.
// NOTE: the socket is closed before the caller uses the port, so another
// process could grab it in the meantime — an accepted race (see load()).
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0); // port 0 = let the OS pick one
    if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    // read back which port the OS actually assigned
    if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    int port = ntohs(serv_addr.sin_port);
    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif
    return port;
}
  300. // helper to convert vector<string> to char **
  301. // pointers are only valid as long as the original vector is valid
  302. static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
  303. std::vector<char *> result;
  304. result.reserve(vec.size() + 1);
  305. for (const auto & s : vec) {
  306. result.push_back(const_cast<char*>(s.c_str()));
  307. }
  308. result.push_back(nullptr);
  309. return result;
  310. }
  311. std::vector<server_model_meta> server_models::get_all_meta() {
  312. std::lock_guard<std::mutex> lk(mutex);
  313. std::vector<server_model_meta> result;
  314. result.reserve(mapping.size());
  315. for (const auto & [name, inst] : mapping) {
  316. result.push_back(inst.meta);
  317. }
  318. return result;
  319. }
// Enforce the models_max limit: when the number of active instances has
// reached the cap, request shutdown of the least-recently-used one.
// models_max <= 0 disables the limit entirely.
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one of the servers if we passed the models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms(); // "now": any active instance is older
    size_t count_active = 0;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    // unload() takes the mutex itself, so it must be called outside the scan's lock
    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
    }
}
// Spawn a child server process for `name` and start its monitor thread.
// No-op (with a log) if the model is not currently UNLOADED. May first evict
// the least-recently-used instance to respect models_max. Throws if the name
// is unknown, no free port can be obtained, or the spawn itself fails.
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // make room under the models_max limit before spawning a new instance
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port(); // note: port could be taken by another process before the child binds it
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        inst.meta.update_args(ctx_preset, bin_path); // render args
        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env = base_env; // copy
        // tell the child which port the router itself listens on
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        // argv/envp must stay alive until subprocess_create_ex returns:
        // they alias child_args/child_env
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        // TODO @ngxson : maybe separate stdout and stderr in the future
        // so that we can use stdout for commands and stderr for logging
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        // keep the child's stdin so we can send it the exit command later
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        bool state_received = false; // true if child state received
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            // this loop blocks until the child closes its stdout, i.e. exits
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
                // the child prints a READY sentinel once its HTTP server is up
                if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
                    // child process is ready
                    this->update_status(name, SERVER_MODEL_STATUS_LOADED);
                    state_received = true;
                }
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update PID and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all(); // wake wait_until_loaded: load finished or failed
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
  439. static void interrupt_subprocess(FILE * stdin_file) {
  440. // because subprocess.h does not provide a way to send SIGINT,
  441. // we will send a command to the child process to exit gracefully
  442. if (stdin_file) {
  443. fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
  444. fflush(stdin_file);
  445. }
  446. }
  447. void server_models::unload(const std::string & name) {
  448. std::lock_guard<std::mutex> lk(mutex);
  449. auto it = mapping.find(name);
  450. if (it != mapping.end()) {
  451. if (it->second.meta.is_active()) {
  452. SRV_INF("unloading model instance name=%s\n", name.c_str());
  453. interrupt_subprocess(it->second.stdin_file);
  454. // status change will be handled by the managing thread
  455. } else {
  456. SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
  457. }
  458. }
  459. }
// Request shutdown of every active instance and wait for all monitoring
// threads to finish. The threads are moved out of the map while holding the
// lock, then joined after releasing it: each monitoring thread re-takes the
// mutex on child exit, so joining under the lock would deadlock.
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // moving the thread to join list to avoid deadlock
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
  480. void server_models::update_status(const std::string & name, server_model_status status) {
  481. // for now, we only allow updating to LOADED status
  482. if (status != SERVER_MODEL_STATUS_LOADED) {
  483. throw std::runtime_error("invalid status value");
  484. }
  485. auto meta = get_meta(name);
  486. if (meta.has_value()) {
  487. meta->status = status;
  488. update_meta(name, meta.value());
  489. }
  490. }
  491. void server_models::wait_until_loaded(const std::string & name) {
  492. std::unique_lock<std::mutex> lk(mutex);
  493. cv.wait(lk, [this, &name]() {
  494. auto it = mapping.find(name);
  495. if (it != mapping.end()) {
  496. return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
  497. }
  498. return false;
  499. });
  500. }
// Make sure `name` is loaded, blocking until it is ready.
// Returns false if the model was already loaded, true if this call waited for
// a load to complete. Throws if the model is unknown or ends up failed.
bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name);
    }
    // at this point the model is LOADING (started here or by another caller)
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
// Forward an HTTP request to the child instance serving `name` and return a
// streaming proxy response object. update_last_used refreshes the LRU
// timestamp consulted by unload_lru() (callers enable it for POST requests).
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        // NOTE(review): deliberately a different exception type than the
        // "not found" case above — presumably so callers can distinguish the
        // two failures; confirm against the HTTP handlers
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    // children always listen on the loopback CHILD_ADDR at their assigned port
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
// Runs inside a CHILD server process. Prints the READY sentinel on stdout so
// the router (which tails our combined stdout/stderr, see load()) can mark
// this model as LOADED, then returns a thread that watches stdin for router
// commands. Logging is paused around the sentinel so it comes out as a clean,
// unmixed line.
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    common_log_pause(common_log_main());
    fflush(stdout);
    fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
    fflush(stdout);
    common_log_resume(common_log_main());
    // setup thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, that means the router server is unexpectedly exit or killed
                eof = true;
                break;
            }
            if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
                // graceful shutdown requested by the router
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            // router died without sending the exit command: hard-exit so this
            // child does not linger as an orphan process
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
  576. //
  577. // server_models_routes
  578. //
  579. static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
  580. res->status = 200;
  581. res->data = safe_json_to_str(response_data);
  582. }
  583. static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
  584. res->status = json_value(error_data, "code", 500);
  585. res->data = safe_json_to_str({{ "error", error_data }});
  586. }
  587. static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
  588. if (name.empty()) {
  589. res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
  590. return false;
  591. }
  592. auto meta = models.get_meta(name);
  593. if (!meta.has_value()) {
  594. res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
  595. return false;
  596. }
  597. if (models_autoload) {
  598. models.ensure_model_loaded(name);
  599. } else {
  600. if (meta->status != SERVER_MODEL_STATUS_LOADED) {
  601. res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
  602. return false;
  603. }
  604. }
  605. return true;
  606. }
  607. static bool is_autoload(const common_params & params, const server_http_req & req) {
  608. std::string autoload = req.get_param("autoload");
  609. if (autoload.empty()) {
  610. return params.models_autoload;
  611. } else {
  612. return autoload == "true" || autoload == "1";
  613. }
  614. }
// Wire up the router's HTTP handlers. Each handler is a lambda stored on this
// object; GET handlers select the model via the "model" query parameter,
// POST handlers via the "model" field of the JSON body.
// NOTE(review): json::parse(req.body) throws on malformed input — presumably
// the surrounding HTTP layer converts that into an error response; confirm.
void server_models_routes::init_routes() {
    // GET /props: without ?model=..., answer for the router itself;
    // otherwise proxy to the selected child instance
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
                {"webui_settings", webui_settings},
            });
            return res;
        }
        return proxy_get(req);
    };
    // generic GET proxy: validates (and possibly autoloads) the model first
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };
    // generic POST proxy: same as GET, but the model comes from the JSON body
    // and the request refreshes the model's LRU timestamp
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };
    // POST /models/load: explicitly spawn the instance for one model
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };
    // GET /models: OAI-compatible list of all known models with router status
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (!meta.preset.name.empty()) {
                // expose the preset as INI, minus router-managed options
                common_preset preset_copy = meta.preset;
                unset_reserved_args(preset_copy, false);
                preset_copy.unset_option("LLAMA_ARG_HOST");
                preset_copy.unset_option("LLAMA_ARG_PORT");
                preset_copy.unset_option("LLAMA_ARG_ALIAS");
                status["preset"] = preset_copy.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };
    // POST /models/unload: request graceful shutdown of one loaded model
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
  731. //
  732. // server_http_proxy
  733. //
  734. // simple implementation of a pipe
  735. // used for streaming data between threads
// Single-producer / single-consumer in-memory pipe: write() pushes chunks,
// read() pops them in FIFO order, and either endpoint may close independently.
// The closed flags are atomics so close_read()/close_write() can be called
// without first taking the mutex.
template<typename T>
struct pipe_t {
    std::mutex mutex;                        // guards queue
    std::condition_variable cv;              // signals new data / state change
    std::queue<T> queue;                     // buffered chunks (FIFO)
    std::atomic<bool> writer_closed{false};  // writer done: reader drains, then EOF
    std::atomic<bool> reader_closed{false};  // reader gone: writer sees broken pipe
    // Writer side is finished: the reader drains what is queued, then gets EOF.
    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // Reader side is gone: subsequent write() calls report a broken pipe.
    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }
    // Pop the next chunk into `output`. Returns false on clean EOF (writer
    // closed and queue drained) or on cancellation via should_stop(); in the
    // cancellation case the read end is closed so the writer unblocks too.
    // should_stop() is polled every 500 ms rather than waiting indefinitely.
    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }
    // Push one chunk; returns false if the reader already closed (broken pipe).
    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
  780. static std::string to_lower_copy(const std::string & value) {
  781. std::string lowered(value.size(), '\0');
  782. std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
  783. return lowered;
  784. }
  785. static bool should_strip_proxy_header(const std::string & header_name) {
  786. // Headers that get duplicated when router forwards child responses
  787. if (header_name == "server" ||
  788. header_name == "transfer-encoding" ||
  789. header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
  790. header_name == "keep-alive") {
  791. return true;
  792. }
  793. // Router injects CORS, child also sends them: duplicate
  794. if (header_name.rfind("access-control-", 0) == 0) {
  795. return true;
  796. }
  797. return false;
  798. }
// Proxy a single HTTP request to a child server at host:port.
//
// The actual request runs on a detached background thread; data flows back
// through a pipe_t<msg_t>. Protocol on the pipe: the first message carries
// status/headers/content-type, every following message carries one body chunk.
// The constructor blocks only until that first (header) message arrives, so
// status/headers/content_type are populated on return; body chunks are then
// pulled lazily via this->next().
//
// `should_stop` is polled while waiting on the pipe; returning true cancels
// the transfer (the pipe is marked reader-closed, which makes the background
// thread's writes fail and stops the download).
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between reader and writer threads; shared_ptr keeps them alive
    // after this constructor returns (the worker thread is detached below)
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    // closing both ends unblocks whichever thread is waiting on the pipe
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            // strip headers the router manages itself (compare lowercased,
            // since HTTP header names are case-insensitive)
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            // content-type travels in its own msg_t field, not in the header map
            if (lowered == "content-type") {
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        // aggregate order appears to be {headers, status, data, content_type} — matches msg_t declaration elsewhere
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // transport-level failure: synthesize a 500 header message plus a
            // body chunk describing the error, following the same
            // header-first protocol the reader expects
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detached on purpose: lifetime is managed via the shared_ptr captures,
    // not by joining; the destructor of this object must not block on the transfer
    this->thread.detach();
    // wait for the first chunk (headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            // pipe closed before any header message: request was cancelled
            // or failed before a response arrived; this->status stays 500
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}