#include "server-common.h"
#include "server-models.h"
#include "preset.h"
#include "download.h"

#include <cpp-httplib/httplib.h> // TODO: remove this once we use the HTTP client from download.h
#include <sheredom/subprocess.h>

#include <functional>
#include <algorithm>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <cstring>
#include <atomic>
#include <chrono>
#include <queue>
#include <filesystem>

#ifdef _WIN32
#include <winsock2.h>
#else
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>
#endif

#if defined(__APPLE__) && defined(__MACH__)
// macOS: use _NSGetExecutablePath to get the executable path
#include <mach-o/dyld.h>
#include <limits.h>
#endif
#define DEFAULT_STOP_TIMEOUT 10 // seconds

#define CMD_ROUTER_TO_CHILD_EXIT  "cmd_router_to_child:exit"
#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"

// address for the child process; this is needed because the router may listen on 0.0.0.0
// ref: https://github.com/ggml-org/llama.cpp/issues/17862
#define CHILD_ADDR "127.0.0.1"
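// the commands above form a minimal line-based control protocol carried over the
// child's stdio (a sketch of the flow as implemented in this file):
//   child  -> router : CMD_CHILD_TO_ROUTER_READY once the model finished loading
//                      (printed on stdout, picked up by the router's log thread)
//   router -> child  : CMD_ROUTER_TO_CHILD_EXIT to request a graceful shutdown
//                      (written to the child's stdin)
//   EOF on the child's stdin means the router itself died; the child then exits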
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get the absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate the required size and call again
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    return std::filesystem::path(std::string(path, count));
#endif
}
static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
    preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
    preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
    preset.unset_option("LLAMA_API_KEY");
    preset.unset_option("LLAMA_ARG_MODELS_DIR");
    preset.unset_option("LLAMA_ARG_MODELS_MAX");
    preset.unset_option("LLAMA_ARG_MODELS_PRESET");
    preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
    if (unset_model_args) {
        preset.unset_option("LLAMA_ARG_MODEL");
        preset.unset_option("LLAMA_ARG_MMPROJ");
        preset.unset_option("LLAMA_ARG_HF_REPO");
    }
}
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    unset_reserved_args(preset, false);
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate the preset before rendering?
    // render args
    args = preset.to_args(bin_path);
}
//
// server_models
//

server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp) :
    ctx_preset(LLAMA_EXAMPLE_SERVER),
    base_params(params),
    base_preset(ctx_preset.load_from_args(argc, argv)) {
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    // clean up the base preset
    unset_reserved_args(base_preset, true);
    // set the binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}
void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    meta.update_args(ctx_preset, bin_path); // render args
    std::string name = meta.name;
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<subprocess_s>(),
        /* th      */ std::thread(),
        /* meta    */ std::move(meta)
    };
}
// TODO: allow refreshing the cached model list
void server_models::load_models() {
    // models are loaded from 3 sources, in increasing order of precedence:
    //   cached models < local models (--models-dir) < custom presets (INI) < base preset (CLI args)
    // 1. cached models
    common_presets cached_models = ctx_preset.load_from_cache();
    SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
    // 2. local models from --models-dir
    common_presets local_models;
    if (!base_params.models_dir.empty()) {
        local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
        SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
    }
    // 3. custom-path models from presets
    common_preset  global         = {};
    common_presets custom_presets = {};
    if (!base_params.models_preset.empty()) {
        custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
        SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
    }
    // cascade, apply the global preset first
    cached_models  = ctx_preset.cascade(global, cached_models);
    local_models   = ctx_preset.cascade(global, local_models);
    custom_presets = ctx_preset.cascade(global, custom_presets);
    // note: if a model exists in both cached and local, local takes precedence
    common_presets final_presets;
    for (const auto & [name, preset] : cached_models) {
        final_presets[name] = preset;
    }
    for (const auto & [name, preset] : local_models) {
        final_presets[name] = preset;
    }
    // process custom presets from the INI file
    for (const auto & [name, custom] : custom_presets) {
        if (final_presets.find(name) != final_presets.end()) {
            // apply the custom config on top of the existing preset
            common_preset & target = final_presets[name];
            target.merge(custom);
        } else {
            // otherwise, add it directly
            final_presets[name] = custom;
        }
    }
    // the server base preset from CLI args takes the highest precedence
    for (auto & [name, preset] : final_presets) {
        preset.merge(base_preset);
    }
    // convert presets to server_model_meta and add them to the mapping
    for (const auto & preset : final_presets) {
        server_model_meta meta{
            /* preset       */ preset.second,
            /* name         */ preset.first,
            /* port         */ 0,
            /* status       */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used    */ 0,
            /* args         */ std::vector<std::string>(),
            /* exit_code    */ 0,
            /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
        };
        add_model(std::move(meta));
    }
    // log the available models
    {
        std::unordered_set<std::string> custom_names;
        for (const auto & [name, preset] : custom_presets) {
            custom_names.insert(name);
        }
        SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
        for (const auto & [name, inst] : mapping) {
            bool has_custom = custom_names.find(name) != custom_names.end();
            SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
        }
    }
    // handle the custom stop-timeout option
    for (auto & [name, inst] : mapping) {
        std::string val;
        if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
            try {
                inst.meta.stop_timeout = std::stoi(val);
            } catch (...) {
                SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
                        val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
                inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
            }
        }
    }
    // load any autoload models
    std::vector<std::string> models_to_load;
    for (const auto & [name, inst] : mapping) {
        std::string val;
        if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
            models_to_load.push_back(name);
        }
    }
    if ((int) models_to_load.size() > base_params.models_max) {
        throw std::runtime_error(string_format(
            "number of models to load on startup (%zu) exceeds models_max (%d)",
            models_to_load.size(),
            base_params.models_max
        ));
    }
    for (const auto & name : models_to_load) {
        SRV_INF("(startup) loading model %s\n", name.c_str());
        load(name);
    }
}
void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        it->second.meta = meta;
    }
    cv.notify_all(); // notify wait_until_loaded
}

bool server_models::has_model(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    return mapping.find(name) != mapping.end();
}

std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        return it->second.meta;
    }
    return std::nullopt;
}
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family      = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port        = htons(0);
    if (bind(sock, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    if (getsockname(sock, (struct sockaddr *) &serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    int port = ntohs(serv_addr.sin_port);
    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif
    return port;
}
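// note on the helper above: binding to port 0 asks the OS for an ephemeral port,
// which is read back via getsockname() and then released. this is inherently racy
// (another process could grab the port between CLOSE_SOCKET() and the child's own
// bind()), but it is a common, pragmatic pattern for spawning local children.
// a minimal sketch of the intended use, as done in server_models::load() below:
//
//   int port = get_free_port();  // probe for a free port; <= 0 means failure
//   // pass `port` to the child via LLAMA_ARG_PORT; the child binds it for real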
// helper to convert a vector<string> to char **
// the pointers are only valid as long as the original vector is alive
static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
    std::vector<char *> result;
    result.reserve(vec.size() + 1);
    for (const auto & s : vec) {
        result.push_back(const_cast<char *>(s.c_str()));
    }
    result.push_back(nullptr); // the array must be null-terminated
    return result;
}
std::vector<server_model_meta> server_models::get_all_meta() {
    std::lock_guard<std::mutex> lk(mutex);
    std::vector<server_model_meta> result;
    result.reserve(mapping.size());
    for (const auto & [name, inst] : mapping) {
        result.push_back(inst.meta);
    }
    return result;
}
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // unload one of the servers if we exceed models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms();
    size_t count_active = 0;
    {
        std::unique_lock<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used  = m.second.meta.last_used;
                }
            }
        }
    }
    if (!lru_model_name.empty() && count_active >= (size_t) base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
        // wait for the unload to complete
        {
            std::unique_lock<std::mutex> lk(mutex);
            cv.wait(lk, [this, &lru_model_name]() {
                return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
            });
        }
    }
}
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model %s is already loading or loaded, skipping\n", name.c_str());
        return;
    }
    // prepare the new instance info
    instance_t inst;
    inst.meta           = meta;
    inst.meta.port      = get_free_port();
    inst.meta.status    = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        inst.meta.update_args(ctx_preset, bin_path); // render args
        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env  = base_env;       // copy
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF("  %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        // TODO @ngxson : maybe separate stdout and stderr in the future
        // so that we can use stdout for commands and stderr for logging
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result  = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
        FILE * stdin_file  = subprocess_stdin(child_proc.get());
        FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
        // set (under this->mutex) once the child process has exited; this unblocks the
        // stopping thread below even when the child exits on its own (e.g. failed to
        // load the model), otherwise stopping_thread.join() would wait forever
        bool child_exited = false;
        std::thread log_thread([&]() {
            // read stdout/stderr and forward it to the main server log
            // also handle the status report from the child process
            bool state_received = false; // true if the child state was received
            if (stdout_file) {
                char buffer[4096];
                while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
                    LOG("[%5d] %s", port, buffer);
                    if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
                        // the child process is ready
                        this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
                        state_received = true;
                    }
                }
            } else {
                SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
            }
        });
        std::thread stopping_thread([&]() {
            // thread to monitor the stopping signal
            auto is_stopping = [this, &name]() {
                return this->stopping_models.find(name) != this->stopping_models.end();
            };
            {
                std::unique_lock<std::mutex> lk(this->mutex);
                this->cv_stop.wait(lk, [&]() { return is_stopping() || child_exited; });
                if (child_exited) {
                    return; // the child exited on its own, nothing to stop
                }
            }
            SRV_INF("stopping model instance name=%s\n", name.c_str());
            // send the exit command to the child process
            fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
            fflush(stdin_file);
            // wait for it to stop gracefully, or force-kill on timeout
            int64_t start_time = ggml_time_ms();
            while (true) {
                std::unique_lock<std::mutex> lk(this->mutex);
                if (!is_stopping()) {
                    return; // already stopped
                }
                int64_t elapsed = ggml_time_ms() - start_time;
                if (elapsed >= stop_timeout * 1000) {
                    // timeout, force kill
                    SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
                    subprocess_terminate(child_proc.get());
                    return;
                }
                this->cv_stop.wait_for(lk, std::chrono::seconds(1));
            }
        });
        // log_thread.join() returns when the child's stdout reaches EOF, i.e. when the child process exits
        // note: we cannot call subprocess_join() before this point, because it would close stdin_file,
        // which the stopping thread may still need in order to send the exit command
        if (log_thread.joinable()) {
            log_thread.join();
        }
        // stop the timeout monitoring thread
        {
            std::lock_guard<std::mutex> lk(this->mutex);
            child_exited = true;
            stopping_models.erase(name);
            cv_stop.notify_all();
        }
        if (stopping_thread.joinable()) {
            stopping_thread.join();
        }
        // get the exit code
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update the status and exit code
        this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up the old process/thread if it exists
    {
        auto & old_instance = mapping[name];
        // the old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
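// thread topology per model instance (as implemented in load() above):
//
//   managing thread (inst.th)
//    |- log_thread      : pumps the child's combined stdout/stderr into the router
//    |                    log and watches for CMD_CHILD_TO_ROUTER_READY
//    '- stopping_thread : waits for an unload request, sends CMD_ROUTER_TO_CHILD_EXIT,
//                         and force-kills the child after stop_timeout seconds
//
// the managing thread joins log_thread (i.e. waits for the child to exit), then wakes
// and joins stopping_thread, and finally records the child's exit code.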
void server_models::unload(const std::string & name) {
    std::lock_guard<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        if (it->second.meta.is_active()) {
            SRV_INF("unloading model instance name=%s\n", name.c_str());
            stopping_models.insert(name);
            cv_stop.notify_all();
            // the status change will be handled by the managing thread
        } else {
            SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
        }
    }
}

void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                stopping_models.insert(name);
                cv_stop.notify_all();
                // the status change will be handled by the managing thread
            }
            // move the threads to a join list, so we don't hold the lock while joining (deadlock)
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
    std::unique_lock<std::mutex> lk(mutex);
    auto it = mapping.find(name);
    if (it != mapping.end()) {
        auto & meta = it->second.meta;
        meta.status    = status;
        meta.exit_code = exit_code;
    }
    cv.notify_all();
}

void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false;
    });
}
bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name);
    }
    // the model is now in the loading state; wait for it to finish
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check the final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that this model instance is ready
    common_log_pause(common_log_main());
    fflush(stdout);
    fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
    fflush(stdout);
    common_log_resume(common_log_main());
    // set up a thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, meaning the router server exited unexpectedly or was killed
                eof = true;
                break;
            }
            if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
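// illustrative sketch of how a child instance might wire this up (the actual call
// site lives in the server's main entry point, outside this file; the names below
// are assumptions for illustration only):
//
//   // after the model finished loading, and only when spawned by a router:
//   std::thread stdin_monitor = server_models::setup_child_server([](int code) {
//       // trigger a graceful shutdown of the child's HTTP server here
//   });
//   // ...
//   stdin_monitor.join(); // join (or detach) during shutdown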
//
// server_models_routes
//

static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
    res->status = 200;
    res->data   = safe_json_to_str(response_data);
}

static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
    res->status = json_value(error_data, "code", 500);
    res->data   = safe_json_to_str({{ "error", error_data }});
}

static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
    if (name.empty()) {
        res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    auto meta = models.get_meta(name);
    if (!meta.has_value()) {
        res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    if (models_autoload) {
        models.ensure_model_loaded(name);
    } else {
        if (meta->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return false;
        }
    }
    return true;
}
static bool is_autoload(const common_params & params, const server_http_req & req) {
    std::string autoload = req.get_param("autoload");
    if (autoload.empty()) {
        return params.models_autoload;
    } else {
        return autoload == "true" || autoload == "1";
    }
}
void server_models_routes::init_routes() {
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on the web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure the webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
                {"webui_settings", webui_settings},
            });
            return res;
        }
        return proxy_get(req);
    };

    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };

    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update the last usage for POST requests only
    };

    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };

    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args",  meta.args},
            };
            if (!meta.preset.name.empty()) {
                common_preset preset_copy = meta.preset;
                unset_reserved_args(preset_copy, false);
                preset_copy.unset_option("LLAMA_ARG_HOST");
                preset_copy.unset_option("LLAMA_ARG_PORT");
                preset_copy.unset_option("LLAMA_ARG_ALIAS");
                status["preset"] = preset_copy.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"]    = true;
            }
            models_json.push_back(json {
                {"id",       meta.name},
                {"object",   "model"},    // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created",  t},          // for OAI-compat
                {"status",   status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data",   models_json},
            {"object", "list"},
        });
        return res;
    };

    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (!model->is_active()) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
//
// server_http_proxy
//

// simple implementation of a pipe
// used for streaming data between threads
template<typename T>
struct pipe_t {
    std::mutex mutex;
    std::condition_variable cv;
    std::queue<T> queue;
    std::atomic<bool> writer_closed{false};
    std::atomic<bool> reader_closed{false};

    void close_write() {
        writer_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }

    void close_read() {
        reader_closed.store(true, std::memory_order_relaxed);
        cv.notify_all();
    }

    bool read(T & output, const std::function<bool()> & should_stop) {
        std::unique_lock<std::mutex> lk(mutex);
        constexpr auto poll_interval = std::chrono::milliseconds(500);
        while (true) {
            if (!queue.empty()) {
                output = std::move(queue.front());
                queue.pop();
                return true;
            }
            if (writer_closed.load()) {
                return false; // clean EOF
            }
            if (should_stop()) {
                close_read(); // signal broken pipe to the writer
                return false; // cancelled / reader no longer alive
            }
            cv.wait_for(lk, poll_interval);
        }
    }

    bool write(T && data) {
        std::lock_guard<std::mutex> lk(mutex);
        if (reader_closed.load()) {
            return false; // broken pipe
        }
        queue.push(std::move(data));
        cv.notify_one();
        return true;
    }
};
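// illustrative usage of pipe_t (a sketch, not code used verbatim in this file):
// one thread produces chunks, another consumes them until EOF or cancellation.
//
//   pipe_t<std::string> pipe;
//   std::thread producer([&]() {
//       pipe.write("chunk 1");
//       pipe.write("chunk 2");
//       pipe.close_write(); // signal EOF to the reader
//   });
//   std::string chunk;
//   while (pipe.read(chunk, []() { return false; /* never cancel */ })) {
//       // process `chunk`
//   }
//   producer.join();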
static std::string to_lower_copy(const std::string & value) {
    std::string lowered(value.size(), '\0');
    std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
    return lowered;
}

static bool should_strip_proxy_header(const std::string & header_name) {
    // headers that get duplicated when the router forwards child responses
    if (header_name == "server" ||
        header_name == "transfer-encoding" ||
        header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
        header_name == "keep-alive") {
        return true;
    }
    // the router injects CORS headers and the child also sends them: duplicate
    if (header_name.rfind("access-control-", 0) == 0) {
        return true;
    }
    return false;
}
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between the reader and writer threads
    auto cli  = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // set up the client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receiving end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or the pipe is broken
    };
    // wire up the HTTP client
    // note: do NOT capture the `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            if (lowered == "content-type") {
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send the headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send the data chunks
        // returns false if the pipe is closed / broken (signal to stop receiving)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to the destination server
    httplib::Request req;
    {
        req.method = method;
        req.path   = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body             = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""});                      // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to the reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    this->thread.detach();
    // wait for the first chunk (the headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status  = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}
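// how the pieces above fit together (an illustrative note, not additional API):
// the constructor blocks only until the response headers arrive, then returns;
// the body is streamed afterwards by repeatedly calling `next()`. a sketch:
//
//   server_http_proxy proxy("GET", CHILD_ADDR, port, "/health", {}, "", should_stop);
//   // proxy.status / proxy.headers / proxy.content_type are now populated
//   std::string chunk;
//   while (proxy.next(chunk)) {
//       // forward `chunk` to the downstream client
//   }
//   // `cleanup` closes both ends of the pipe when the proxy is torn down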