server-models.cpp 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092
  1. #include "server-common.h"
  2. #include "server-models.h"
  3. #include "preset.h"
  4. #include "download.h"
  5. #include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
  6. #include <sheredom/subprocess.h>
  7. #include <functional>
  8. #include <algorithm>
  9. #include <thread>
  10. #include <mutex>
  11. #include <condition_variable>
  12. #include <cstring>
  13. #include <atomic>
  14. #include <chrono>
  15. #include <queue>
  16. #include <filesystem>
  17. #include <cstring>
  18. #ifdef _WIN32
  19. #include <winsock2.h>
  20. #include <windows.h>
  21. #else
  22. #include <sys/socket.h>
  23. #include <netinet/in.h>
  24. #include <arpa/inet.h>
  25. #include <unistd.h>
  26. extern char **environ;
  27. #endif
  28. #if defined(__APPLE__) && defined(__MACH__)
  29. // macOS: use _NSGetExecutablePath to get the executable path
  30. #include <mach-o/dyld.h>
  31. #include <limits.h>
  32. #endif
  33. #define DEFAULT_STOP_TIMEOUT 10 // seconds
  34. #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit"
  35. #define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready"
  36. // address for child process, this is needed because router may run on 0.0.0.0
  37. // ref: https://github.com/ggml-org/llama.cpp/issues/17862
  38. #define CHILD_ADDR "127.0.0.1"
// Resolve the absolute path of the currently running server executable.
// Platform-specific: Win32 GetModuleFileNameW, macOS _NSGetExecutablePath,
// otherwise /proc/self/exe (Linux-style). Throws std::runtime_error on failure.
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        // len == _countof(buf) indicates the path was truncated
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            // canonical() throws if the path cannot be resolved; fall back to the raw path
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        // (on failure, _NSGetExecutablePath stores the needed length in 'size')
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    char path[FILENAME_MAX];
    // readlink() does not NUL-terminate; keep the byte count and build the string from it
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    return std::filesystem::path(std::string(path, count));
#endif
}
  78. static void unset_reserved_args(common_preset & preset, bool unset_model_args) {
  79. preset.unset_option("LLAMA_ARG_SSL_KEY_FILE");
  80. preset.unset_option("LLAMA_ARG_SSL_CERT_FILE");
  81. preset.unset_option("LLAMA_API_KEY");
  82. preset.unset_option("LLAMA_ARG_MODELS_DIR");
  83. preset.unset_option("LLAMA_ARG_MODELS_MAX");
  84. preset.unset_option("LLAMA_ARG_MODELS_PRESET");
  85. preset.unset_option("LLAMA_ARG_MODELS_AUTOLOAD");
  86. if (unset_model_args) {
  87. preset.unset_option("LLAMA_ARG_MODEL");
  88. preset.unset_option("LLAMA_ARG_MMPROJ");
  89. preset.unset_option("LLAMA_ARG_HF_REPO");
  90. }
  91. }
  92. #ifdef _WIN32
  93. static std::string wide_to_utf8(const wchar_t * ws) {
  94. if (!ws || !*ws) {
  95. return {};
  96. }
  97. const int len = static_cast<int>(std::wcslen(ws));
  98. const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr);
  99. if (bytes == 0) {
  100. return {};
  101. }
  102. std::string utf8(bytes, '\0');
  103. WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr);
  104. return utf8;
  105. }
  106. #endif
// Snapshot the current process environment as a list of "KEY=VALUE" strings.
// Used as the base environment for spawned child server processes.
static std::vector<std::string> get_environment() {
    std::vector<std::string> env;
#ifdef _WIN32
    // GetEnvironmentStringsW returns a double-NUL-terminated block of
    // NUL-separated wide strings; convert each entry to UTF-8
    LPWCH env_block = GetEnvironmentStringsW();
    if (!env_block) {
        return env;
    }
    for (LPWCH e = env_block; *e; e += wcslen(e) + 1) {
        env.emplace_back(wide_to_utf8(e));
    }
    FreeEnvironmentStringsW(env_block);
#else
    // POSIX: environ is a NULL-terminated array of "KEY=VALUE" strings
    if (environ == nullptr) {
        return env;
    }
    for (char ** e = environ; *e != nullptr; e++) {
        env.emplace_back(*e);
    }
#endif
    return env;
}
// Re-render the command line for this model instance: force the
// router-controlled host/port/alias options onto the preset, then
// serialize it into an argv list headed by the server binary path.
void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) {
    // update params
    unset_reserved_args(preset, false); // keep model args, drop router-only options
    // children always bind to loopback; the router proxies external traffic (see CHILD_ADDR note)
    preset.set_option(ctx_preset, "LLAMA_ARG_HOST", CHILD_ADDR);
    preset.set_option(ctx_preset, "LLAMA_ARG_PORT", std::to_string(port));
    preset.set_option(ctx_preset, "LLAMA_ARG_ALIAS", name);
    // TODO: maybe validate preset before rendering ?
    // render args
    args = preset.to_args(bin_path);
}
  138. //
  139. // server_models
  140. //
// Build the model registry: capture the base CLI preset and process
// environment, resolve the server binary path (used to respawn child
// instances), then discover and register all models.
server_models::server_models(
    const common_params & params,
    int argc,
    char ** argv)
        : ctx_preset(LLAMA_EXAMPLE_SERVER),
        base_params(params),
        base_env(get_environment()),
        base_preset(ctx_preset.load_from_args(argc, argv)) {
    // NOTE(review): base_preset's initializer uses ctx_preset, so this relies on
    // ctx_preset being declared before base_preset in the class — confirm in header
    // clean up base preset
    unset_reserved_args(base_preset, true);
    // set binary path
    try {
        bin_path = get_server_exec_path().string();
    } catch (const std::exception & e) {
        // fall back to argv[0]; may be a relative path, but better than nothing
        bin_path = argv[0];
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", argv[0]);
    }
    load_models();
}
// Register a new model entry; throws if the name is already taken.
// The instance starts with a fresh (not-yet-spawned) subprocess handle
// and no managing thread.
void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    meta.update_args(ctx_preset, bin_path); // render args
    std::string name = meta.name; // copy the key before meta is moved below
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<subprocess_s>(),
        /* th */ std::thread(),
        /* meta */ std::move(meta)
    };
}
  173. // TODO: allow refreshing cached model list
  174. void server_models::load_models() {
  175. // loading models from 3 sources:
  176. // 1. cached models
  177. common_presets cached_models = ctx_preset.load_from_cache();
  178. SRV_INF("Loaded %zu cached model presets\n", cached_models.size());
  179. // 2. local models from --models-dir
  180. common_presets local_models;
  181. if (!base_params.models_dir.empty()) {
  182. local_models = ctx_preset.load_from_models_dir(base_params.models_dir);
  183. SRV_INF("Loaded %zu local model presets from %s\n", local_models.size(), base_params.models_dir.c_str());
  184. }
  185. // 3. custom-path models from presets
  186. common_preset global = {};
  187. common_presets custom_presets = {};
  188. if (!base_params.models_preset.empty()) {
  189. custom_presets = ctx_preset.load_from_ini(base_params.models_preset, global);
  190. SRV_INF("Loaded %zu custom model presets from %s\n", custom_presets.size(), base_params.models_preset.c_str());
  191. }
  192. // cascade, apply global preset first
  193. cached_models = ctx_preset.cascade(global, cached_models);
  194. local_models = ctx_preset.cascade(global, local_models);
  195. custom_presets = ctx_preset.cascade(global, custom_presets);
  196. // note: if a model exists in both cached and local, local takes precedence
  197. common_presets final_presets;
  198. for (const auto & [name, preset] : cached_models) {
  199. final_presets[name] = preset;
  200. }
  201. for (const auto & [name, preset] : local_models) {
  202. final_presets[name] = preset;
  203. }
  204. // process custom presets from INI
  205. for (const auto & [name, custom] : custom_presets) {
  206. if (final_presets.find(name) != final_presets.end()) {
  207. // apply custom config if exists
  208. common_preset & target = final_presets[name];
  209. target.merge(custom);
  210. } else {
  211. // otherwise add directly
  212. final_presets[name] = custom;
  213. }
  214. }
  215. // server base preset from CLI args take highest precedence
  216. for (auto & [name, preset] : final_presets) {
  217. preset.merge(base_preset);
  218. }
  219. // convert presets to server_model_meta and add to mapping
  220. for (const auto & preset : final_presets) {
  221. server_model_meta meta{
  222. /* preset */ preset.second,
  223. /* name */ preset.first,
  224. /* port */ 0,
  225. /* status */ SERVER_MODEL_STATUS_UNLOADED,
  226. /* last_used */ 0,
  227. /* args */ std::vector<std::string>(),
  228. /* exit_code */ 0,
  229. /* stop_timeout */ DEFAULT_STOP_TIMEOUT,
  230. };
  231. add_model(std::move(meta));
  232. }
  233. // log available models
  234. {
  235. std::unordered_set<std::string> custom_names;
  236. for (const auto & [name, preset] : custom_presets) {
  237. custom_names.insert(name);
  238. }
  239. SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
  240. for (const auto & [name, inst] : mapping) {
  241. bool has_custom = custom_names.find(name) != custom_names.end();
  242. SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str());
  243. }
  244. }
  245. // handle custom stop-timeout option
  246. for (auto & [name, inst] : mapping) {
  247. std::string val;
  248. if (inst.meta.preset.get_option(COMMON_ARG_PRESET_STOP_TIMEOUT, val)) {
  249. try {
  250. inst.meta.stop_timeout = std::stoi(val);
  251. } catch (...) {
  252. SRV_WRN("invalid stop-timeout value '%s' for model '%s', using default %d seconds\n",
  253. val.c_str(), name.c_str(), DEFAULT_STOP_TIMEOUT);
  254. inst.meta.stop_timeout = DEFAULT_STOP_TIMEOUT;
  255. }
  256. }
  257. }
  258. // load any autoload models
  259. std::vector<std::string> models_to_load;
  260. for (const auto & [name, inst] : mapping) {
  261. std::string val;
  262. if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) {
  263. models_to_load.push_back(name);
  264. }
  265. }
  266. if ((int)models_to_load.size() > base_params.models_max) {
  267. throw std::runtime_error(string_format(
  268. "number of models to load on startup (%zu) exceeds models_max (%d)",
  269. models_to_load.size(),
  270. base_params.models_max
  271. ));
  272. }
  273. for (const auto & name : models_to_load) {
  274. SRV_INF("(startup) loading model %s\n", name.c_str());
  275. load(name);
  276. }
  277. }
  278. void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
  279. std::lock_guard<std::mutex> lk(mutex);
  280. auto it = mapping.find(name);
  281. if (it != mapping.end()) {
  282. it->second.meta = meta;
  283. }
  284. cv.notify_all(); // notify wait_until_loaded
  285. }
  286. bool server_models::has_model(const std::string & name) {
  287. std::lock_guard<std::mutex> lk(mutex);
  288. return mapping.find(name) != mapping.end();
  289. }
  290. std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
  291. std::lock_guard<std::mutex> lk(mutex);
  292. auto it = mapping.find(name);
  293. if (it != mapping.end()) {
  294. return it->second.meta;
  295. }
  296. return std::nullopt;
  297. }
  298. static int get_free_port() {
  299. #ifdef _WIN32
  300. WSADATA wsaData;
  301. if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
  302. return -1;
  303. }
  304. typedef SOCKET native_socket_t;
  305. #define INVALID_SOCKET_VAL INVALID_SOCKET
  306. #define CLOSE_SOCKET(s) closesocket(s)
  307. #else
  308. typedef int native_socket_t;
  309. #define INVALID_SOCKET_VAL -1
  310. #define CLOSE_SOCKET(s) close(s)
  311. #endif
  312. native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
  313. if (sock == INVALID_SOCKET_VAL) {
  314. #ifdef _WIN32
  315. WSACleanup();
  316. #endif
  317. return -1;
  318. }
  319. struct sockaddr_in serv_addr;
  320. std::memset(&serv_addr, 0, sizeof(serv_addr));
  321. serv_addr.sin_family = AF_INET;
  322. serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  323. serv_addr.sin_port = htons(0);
  324. if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
  325. CLOSE_SOCKET(sock);
  326. #ifdef _WIN32
  327. WSACleanup();
  328. #endif
  329. return -1;
  330. }
  331. #ifdef _WIN32
  332. int namelen = sizeof(serv_addr);
  333. #else
  334. socklen_t namelen = sizeof(serv_addr);
  335. #endif
  336. if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
  337. CLOSE_SOCKET(sock);
  338. #ifdef _WIN32
  339. WSACleanup();
  340. #endif
  341. return -1;
  342. }
  343. int port = ntohs(serv_addr.sin_port);
  344. CLOSE_SOCKET(sock);
  345. #ifdef _WIN32
  346. WSACleanup();
  347. #endif
  348. return port;
  349. }
  350. // helper to convert vector<string> to char **
  351. // pointers are only valid as long as the original vector is valid
  352. static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
  353. std::vector<char *> result;
  354. result.reserve(vec.size() + 1);
  355. for (const auto & s : vec) {
  356. result.push_back(const_cast<char*>(s.c_str()));
  357. }
  358. result.push_back(nullptr);
  359. return result;
  360. }
  361. std::vector<server_model_meta> server_models::get_all_meta() {
  362. std::lock_guard<std::mutex> lk(mutex);
  363. std::vector<server_model_meta> result;
  364. result.reserve(mapping.size());
  365. for (const auto & [name, inst] : mapping) {
  366. result.push_back(inst.meta);
  367. }
  368. return result;
  369. }
  370. void server_models::unload_lru() {
  371. if (base_params.models_max <= 0) {
  372. return; // no limit
  373. }
  374. // remove one of the servers if we passed the models_max (least recently used - LRU)
  375. std::string lru_model_name = "";
  376. int64_t lru_last_used = ggml_time_ms();
  377. size_t count_active = 0;
  378. {
  379. std::unique_lock<std::mutex> lk(mutex);
  380. for (const auto & m : mapping) {
  381. if (m.second.meta.is_active()) {
  382. count_active++;
  383. if (m.second.meta.last_used < lru_last_used) {
  384. lru_model_name = m.first;
  385. lru_last_used = m.second.meta.last_used;
  386. }
  387. }
  388. }
  389. }
  390. if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
  391. SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
  392. unload(lru_model_name);
  393. // wait for unload to complete
  394. {
  395. std::unique_lock<std::mutex> lk(mutex);
  396. cv.wait(lk, [this, &lru_model_name]() {
  397. return mapping[lru_model_name].meta.status == SERVER_MODEL_STATUS_UNLOADED;
  398. });
  399. }
  400. }
  401. }
// Spawn a child server process for the given model and start a managing
// thread that forwards its logs, watches for the READY handshake, handles
// stop requests (graceful exit with a force-kill timeout) and reaps the
// process on exit. Throws if the model is unknown or no port is available.
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // may first evict the least-recently-used instance to respect models_max
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        inst.meta.update_args(ctx_preset, bin_path); // render args
        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env = base_env; // copy
        // tell the child which port the router listens on
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        // argv/envp only borrow pointers; child_args/child_env stay alive past the spawn
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        // TODO @ngxson : maybe separate stdout and stderr in the future
        // so that we can use stdout for commands and stderr for logging
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() {
        FILE * stdin_file = subprocess_stdin(child_proc.get());
        FILE * stdout_file = subprocess_stdout(child_proc.get()); // combined stdout/stderr
        // forwards child output to the router log and watches (once) for the
        // READY handshake line printed by the child on startup
        std::thread log_thread([&]() {
            // read stdout/stderr and forward to main server log
            // also handle status report from child process
            bool state_received = false; // true if child state received
            if (stdout_file) {
                char buffer[4096];
                // fgets returns nullptr on EOF, i.e. when the child exits
                while (fgets(buffer, sizeof(buffer), stdout_file) != nullptr) {
                    LOG("[%5d] %s", port, buffer);
                    if (!state_received && std::strstr(buffer, CMD_CHILD_TO_ROUTER_READY) != nullptr) {
                        // child process is ready
                        this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0);
                        state_received = true;
                    }
                }
            } else {
                SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
            }
        });
        // waits for an unload request for this model, then asks the child to
        // exit gracefully and force-kills it after stop_timeout seconds
        std::thread stopping_thread([&]() {
            // thread to monitor stopping signal
            auto is_stopping = [this, &name]() {
                return this->stopping_models.find(name) != this->stopping_models.end();
            };
            {
                std::unique_lock<std::mutex> lk(this->mutex);
                this->cv_stop.wait(lk, is_stopping);
            }
            SRV_INF("stopping model instance name=%s\n", name.c_str());
            // send interrupt to child process
            fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT);
            fflush(stdin_file);
            // wait to stop gracefully or timeout
            int64_t start_time = ggml_time_ms();
            while (true) {
                std::unique_lock<std::mutex> lk(this->mutex);
                if (!is_stopping()) {
                    return; // already stopped
                }
                int64_t elapsed = ggml_time_ms() - start_time;
                if (elapsed >= stop_timeout * 1000) {
                    // timeout, force kill
                    SRV_WRN("force-killing model instance name=%s after %d seconds timeout\n", name.c_str(), stop_timeout);
                    subprocess_terminate(child_proc.get());
                    return;
                }
                // poll roughly once a second so a late stop-ack is noticed
                this->cv_stop.wait_for(lk, std::chrono::seconds(1));
            }
        });
        // we reach here when the child process exits
        // note: we cannot join() prior to this point because it will close stdin_file
        if (log_thread.joinable()) {
            log_thread.join();
        }
        // stop the timeout monitoring thread
        {
            std::lock_guard<std::mutex> lk(this->mutex);
            stopping_models.erase(name);
            cv_stop.notify_all();
        }
        if (stopping_thread.joinable()) {
            stopping_thread.join();
        }
        // get the exit code
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update status and exit code
        this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code);
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
  535. void server_models::unload(const std::string & name) {
  536. std::lock_guard<std::mutex> lk(mutex);
  537. auto it = mapping.find(name);
  538. if (it != mapping.end()) {
  539. if (it->second.meta.is_active()) {
  540. SRV_INF("unloading model instance name=%s\n", name.c_str());
  541. stopping_models.insert(name);
  542. cv_stop.notify_all();
  543. // status change will be handled by the managing thread
  544. } else {
  545. SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
  546. }
  547. }
  548. }
// Signal every active instance to stop, then join ALL managing threads
// (including those of already-inactive instances) outside the lock —
// the managing threads themselves take 'mutex', so joining while holding
// it would deadlock. Intended for server shutdown.
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                stopping_models.insert(name);
                cv_stop.notify_all();
                // status change will be handled by the managing thread
            }
            // moving the thread to join list to avoid deadlock
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
  570. void server_models::update_status(const std::string & name, server_model_status status, int exit_code) {
  571. std::unique_lock<std::mutex> lk(mutex);
  572. auto it = mapping.find(name);
  573. if (it != mapping.end()) {
  574. auto & meta = it->second.meta;
  575. meta.status = status;
  576. meta.exit_code = exit_code;
  577. }
  578. cv.notify_all();
  579. }
  580. void server_models::wait_until_loaded(const std::string & name) {
  581. std::unique_lock<std::mutex> lk(mutex);
  582. cv.wait(lk, [this, &name]() {
  583. auto it = mapping.find(name);
  584. if (it != mapping.end()) {
  585. return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
  586. }
  587. return false;
  588. });
  589. }
// Make sure a model is loaded, blocking until it is usable.
// Returns false if it was already loaded, true if loading was performed
// (or awaited). Throws if the model is unknown or ends up failed.
bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name);
    }
    // for loading state
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check final status (it changed while we were waiting)
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
// Forward an HTTP request to the child instance serving 'name' over loopback.
// Throws std::runtime_error for an unknown model and std::invalid_argument
// for a model that is not currently loaded (callers can map these to
// different HTTP error codes). Optionally refreshes the LRU timestamp.
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        // refresh the LRU timestamp (callers enable this for POST/inference requests)
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    // the proxy object performs the actual transfer to/from the child server
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop,
        base_params.timeout_read,
        base_params.timeout_write
    );
    return proxy;
}
// Called from a CHILD server process: announce readiness to the router on
// stdout, then return a thread that watches stdin for the router's exit
// command — or EOF, which means the router itself died — and shuts the
// child down accordingly.
std::thread server_models::setup_child_server(const std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // (pause logging so the handshake line is not interleaved with log output)
    common_log_pause(common_log_main());
    fflush(stdout);
    fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY);
    fflush(stdout);
    common_log_resume(common_log_main());
    // setup thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, that means the router server is unexpectedly exit or killed
                eof = true;
                break;
            }
            if (line.find(CMD_ROUTER_TO_CHILD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0); // graceful shutdown requested by the router
                break;
            }
        }
        if (eof) {
            // router is gone; no point keeping this instance alive
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
  669. //
  670. // server_models_routes
  671. //
  672. static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
  673. res->status = 200;
  674. res->data = safe_json_to_str(response_data);
  675. }
  676. static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
  677. res->status = json_value(error_data, "code", 500);
  678. res->data = safe_json_to_str({{ "error", error_data }});
  679. }
// Validate the model referenced by a request, optionally autoloading it.
// On failure, fills 'res' with an error response and returns false.
// NOTE(review): 'meta' is read before autoloading, so its status may be stale
// afterwards; correctness relies on ensure_model_loaded throwing when the
// load fails — confirm that contract holds.
static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
    if (name.empty()) {
        res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    auto meta = models.get_meta(name);
    if (!meta.has_value()) {
        res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
        return false;
    }
    if (models_autoload) {
        // blocks until the model is usable; throws on load failure
        models.ensure_model_loaded(name);
    } else {
        if (meta->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return false;
        }
    }
    return true;
}
  700. static bool is_autoload(const common_params & params, const server_http_req & req) {
  701. std::string autoload = req.get_param("autoload");
  702. if (autoload.empty()) {
  703. return params.models_autoload;
  704. } else {
  705. return autoload == "true" || autoload == "1";
  706. }
  707. }
// Register the HTTP handler functors for the model-router endpoints.
// Each handler is stored into a std::function member and captures `this`
// so it can reach the shared model registry (`models`), server `params`
// and `webui_settings`.
void server_models_routes::init_routes() {
    // GET /props: with no "model" query parameter, answer on behalf of the
    // router itself with a stub payload; otherwise proxy to the child instance.
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
                {"webui_settings", webui_settings},
            });
            return res;
        }
        return proxy_get(req);
    };
    // Generic GET proxy: the target model name comes from the query string.
    // Validation may trigger an autoload; on failure the prepared error
    // response is returned instead of forwarding.
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };
    // Generic POST proxy: the target model name comes from the JSON body.
    // NOTE(review): json::parse throws on a malformed body — presumably the
    // surrounding HTTP layer converts that into an error response; verify.
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };
    // POST /models/load: explicitly load a model by name.
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };
    // GET /models: list all known models with their status (OAI-compatible
    // "list" object shape).
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0); // single timestamp reused for every entry
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (!meta.preset.name.empty()) {
                // expose the preset as INI text, minus options the router
                // controls itself (host/port/alias) and reserved args
                common_preset preset_copy = meta.preset;
                unset_reserved_args(preset_copy, false);
                preset_copy.unset_option("LLAMA_ARG_HOST");
                preset_copy.unset_option("LLAMA_ARG_PORT");
                preset_copy.unset_option("LLAMA_ARG_ALIAS");
                status["preset"] = preset_copy.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };
    // POST /models/unload: explicitly unload a model by name.
    // NOTE(review): an unknown model here yields ERROR_TYPE_INVALID_REQUEST,
    // while the load handler uses ERROR_TYPE_NOT_FOUND — confirm whether the
    // inconsistency is intentional before changing either.
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (!model->is_active()) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
  824. //
  825. // server_http_proxy
  826. //
  827. // simple implementation of a pipe
  828. // used for streaming data between threads
  829. template<typename T>
  830. struct pipe_t {
  831. std::mutex mutex;
  832. std::condition_variable cv;
  833. std::queue<T> queue;
  834. std::atomic<bool> writer_closed{false};
  835. std::atomic<bool> reader_closed{false};
  836. void close_write() {
  837. writer_closed.store(true, std::memory_order_relaxed);
  838. cv.notify_all();
  839. }
  840. void close_read() {
  841. reader_closed.store(true, std::memory_order_relaxed);
  842. cv.notify_all();
  843. }
  844. bool read(T & output, const std::function<bool()> & should_stop) {
  845. std::unique_lock<std::mutex> lk(mutex);
  846. constexpr auto poll_interval = std::chrono::milliseconds(500);
  847. while (true) {
  848. if (!queue.empty()) {
  849. output = std::move(queue.front());
  850. queue.pop();
  851. return true;
  852. }
  853. if (writer_closed.load()) {
  854. return false; // clean EOF
  855. }
  856. if (should_stop()) {
  857. close_read(); // signal broken pipe to writer
  858. return false; // cancelled / reader no longer alive
  859. }
  860. cv.wait_for(lk, poll_interval);
  861. }
  862. }
  863. bool write(T && data) {
  864. std::lock_guard<std::mutex> lk(mutex);
  865. if (reader_closed.load()) {
  866. return false; // broken pipe
  867. }
  868. queue.push(std::move(data));
  869. cv.notify_one();
  870. return true;
  871. }
  872. };
  873. static std::string to_lower_copy(const std::string & value) {
  874. std::string lowered(value.size(), '\0');
  875. std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
  876. return lowered;
  877. }
  878. static bool should_strip_proxy_header(const std::string & header_name) {
  879. // Headers that get duplicated when router forwards child responses
  880. if (header_name == "server" ||
  881. header_name == "transfer-encoding" ||
  882. header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
  883. header_name == "keep-alive") {
  884. return true;
  885. }
  886. // Router injects CORS, child also sends them: duplicate
  887. if (header_name.rfind("access-control-", 0) == 0) {
  888. return true;
  889. }
  890. return false;
  891. }
// Proxy one HTTP request to a child server instance and stream the response
// back through a pipe_t. The constructor launches a detached client thread
// that performs the request and blocks until the response headers (the first
// pipe message) arrive, so status/headers/content_type are populated before
// the constructor returns.
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop,
        int32_t timeout_read,
        int32_t timeout_write
    ) {
    // shared between reader and writer threads; shared_ptr keeps them alive
    // for the detached thread even after this object is destroyed
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    // deliberately swapped: the router's read deadline bounds how long the
    // client may spend writing, and vice versa
    cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server)
    cli->set_read_timeout(timeout_write, 0);
    this->status = 500; // to be overwritten upon response
    // closing both ends unblocks the client thread and any pending reader
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        // first pipe message carries status + filtered headers, no body data
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue; // router emits its own copy of this header
            }
            if (lowered == "content-type") {
                msg.content_type = value; // kept separately, not in the header map
                continue;
            }
            msg.headers[key] = value; // original casing preserved
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        // (aggregate presumably initializes msg_t as {headers, status, data, content_type} — matches usage above)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // surface transport errors as a synthetic 500 response so the
            // reader side still gets a header message followed by a body
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detached on purpose: lifetime is managed by the shared cli/pipe
    // captures, not by this object
    this->thread.detach();
    // wait for the first chunk (headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            // request was cancelled before any response arrived; status stays 500
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}