// server-models.cpp
  1. #include "server-common.h"
  2. #include "server-models.h"
  3. #include "preset.h"
  4. #include "download.h"
  5. #include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
  6. #include <sheredom/subprocess.h>
  7. #include <functional>
  8. #include <algorithm>
  9. #include <thread>
  10. #include <mutex>
  11. #include <condition_variable>
  12. #include <cstring>
  13. #include <atomic>
  14. #include <chrono>
  15. #include <queue>
  16. #include <filesystem>
  17. #ifdef _WIN32
  18. #include <winsock2.h>
  19. #else
  20. #include <sys/socket.h>
  21. #include <netinet/in.h>
  22. #include <arpa/inet.h>
  23. #include <unistd.h>
  24. #endif
  25. #if defined(__APPLE__) && defined(__MACH__)
  26. // macOS: use _NSGetExecutablePath to get the executable path
  27. #include <mach-o/dyld.h>
  28. #include <limits.h>
  29. #endif
  30. #define CMD_EXIT "exit"
  31. // address for child process, this is needed because router may run on 0.0.0.0
  32. // ref: https://github.com/ggml-org/llama.cpp/issues/17862
  33. #define CHILD_ADDR "127.0.0.1"
// Resolve the absolute filesystem path of the currently-running server
// executable, using the platform-specific API for each OS.
// Throws std::runtime_error if the path cannot be determined.
static std::filesystem::path get_server_exec_path() {
#if defined(_WIN32)
    wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
    DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
    if (len == 0 || len >= _countof(buf)) {
        // len == _countof(buf) means the path was truncated
        throw std::runtime_error("GetModuleFileNameW failed or path too long");
    }
    return std::filesystem::path(buf);
#elif defined(__APPLE__) && defined(__MACH__)
    char small_path[PATH_MAX];
    uint32_t size = sizeof(small_path);
    if (_NSGetExecutablePath(small_path, &size) == 0) {
        // resolve any symlinks to get absolute path
        try {
            return std::filesystem::canonical(std::filesystem::path(small_path));
        } catch (...) {
            // canonical() can throw (e.g. filesystem error); fall back to the raw path
            return std::filesystem::path(small_path);
        }
    } else {
        // buffer was too small, allocate required size and call again
        // (_NSGetExecutablePath wrote the required size into `size`)
        std::vector<char> buf(size);
        if (_NSGetExecutablePath(buf.data(), &size) == 0) {
            try {
                return std::filesystem::canonical(std::filesystem::path(buf.data()));
            } catch (...) {
                return std::filesystem::path(buf.data());
            }
        }
        throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
    }
#else
    // Linux and other systems exposing /proc: resolve the symlink to the binary
    char path[FILENAME_MAX];
    ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
    if (count <= 0) {
        throw std::runtime_error("failed to resolve /proc/self/exe");
    }
    // readlink() does not NUL-terminate, so build the string from the byte count
    return std::filesystem::path(std::string(path, count));
#endif
}
// A model discovered on the local filesystem (as opposed to the model cache).
struct local_model {
    std::string name;        // alias name, derived from the file or directory name
    std::string path;        // path to the main GGUF file (or its first shard)
    std::string path_mmproj; // path to the multimodal projector GGUF; empty if none
};
// Enumerate GGUF models under `dir`:
//  - a top-level *.gguf file is a single-file model (name = file name without ".gguf")
//  - a sub-directory is treated as one model (name = directory name); inside it,
//    a file containing "mmproj" is the projector, a file containing "-00001-of-"
//    is the first shard of a split model (preferred), any other *.gguf is the model
// Throws std::runtime_error if `dir` does not exist or is not a directory.
static std::vector<local_model> list_local_models(const std::string & dir) {
    if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
        throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
    }
    std::vector<local_model> models;
    auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
        auto files = fs_list(subdir_path, false);
        common_file_info model_file;       // plain single-file model
        common_file_info first_shard_file; // first shard of a split model (takes precedence)
        common_file_info mmproj_file;      // optional multimodal projector
        for (const auto & file : files) {
            if (string_ends_with(file.name, ".gguf")) {
                if (file.name.find("mmproj") != std::string::npos) {
                    mmproj_file = file;
                } else if (file.name.find("-00001-of-") != std::string::npos) {
                    first_shard_file = file;
                } else {
                    model_file = file;
                }
            }
        }
        // single file model
        local_model model{
            /* name */ name,
            /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
            /* path_mmproj */ mmproj_file.path // can be empty
        };
        if (!model.path.empty()) {
            models.push_back(model);
        }
    };
    auto files = fs_list(dir, true);
    for (const auto & file : files) {
        if (file.is_dir) {
            scan_subdir(file.path, file.name);
        } else if (string_ends_with(file.name, ".gguf")) {
            // single file model
            std::string name = file.name;
            string_replace_all(name, ".gguf", "");
            local_model model{
                /* name */ name,
                /* path */ file.path,
                /* path_mmproj */ ""
            };
            models.push_back(model);
        }
    }
    return models;
}
  127. //
  128. // server_presets
  129. //
// Load presets from `presets_path` (if given) and split the server CLI options
// into two groups:
//  - control_args: options the router itself manages (host, port, alias, model
//    paths, ...); these are stripped everywhere and re-injected in render_args()
//  - base_args:    all remaining args parsed from the router's own argv, which
//    every child server instance inherits
server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path)
        : ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) {
    if (!presets_path.empty()) {
        presets = common_presets_load(presets_path, ctx_params);
        SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str());
    }
    // populate reserved args (will be appended by the router)
    for (auto & opt : ctx_params.options) {
        if (opt.env == nullptr) {
            // only options addressable via an env var can be router-controlled
            continue;
        }
        std::string env = opt.env;
        if (env == "LLAMA_ARG_PORT" ||
            env == "LLAMA_ARG_HOST" ||
            env == "LLAMA_ARG_ALIAS" ||
            env == "LLAMA_ARG_API_KEY" ||
            env == "LLAMA_ARG_MODELS_DIR" ||
            env == "LLAMA_ARG_MODELS_MAX" ||
            env == "LLAMA_ARG_MODELS_PRESET" ||
            env == "LLAMA_ARG_MODEL" ||
            env == "LLAMA_ARG_MMPROJ" ||
            env == "LLAMA_ARG_HF_REPO" ||
            env == "LLAMA_ARG_NO_MODELS_AUTOLOAD") {
            control_args[env] = opt;
        }
    }
    // read base args from router's argv
    common_params_to_map(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
    // remove any router-controlled args from base_args
    for (const auto & cargs : control_args) {
        auto it = base_args.find(cargs.second);
        if (it != base_args.end()) {
            base_args.erase(it);
        }
    }
}
  166. common_preset server_presets::get_preset(const std::string & name) {
  167. auto it = presets.find(name);
  168. if (it != presets.end()) {
  169. return it->second;
  170. }
  171. return common_preset();
  172. }
// Build the final child-process argument list for `meta` by merging three layers
// (later layers override earlier ones):
//   1. model-specific args from the preset (reserved/control args stripped)
//   2. base args inherited from the router's own command line
//   3. control args decided per-instance by the router (host/port/alias/model)
// The result is stored in meta.args, with the server executable path prepended
// as argv[0].
void server_presets::render_args(server_model_meta & meta) {
    common_preset preset = meta.preset; // copy
    // merging 3 kinds of args:
    // 1. model-specific args (from preset)
    // force removing control args if any
    for (auto & cargs : control_args) {
        if (preset.options.find(cargs.second) != preset.options.end()) {
            SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]);
            preset.options.erase(cargs.second);
        }
    }
    // 2. base args (from router)
    // inherit from base args
    for (const auto & [arg, value] : base_args) {
        preset.options[arg] = value;
    }
    // 3. control args (from router)
    // set control values
    preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR;
    preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port);
    preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name;
    if (meta.in_cache) {
        // cached models are addressed by name via -hf; the child resolves the file path
        preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name;
    } else {
        preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path;
        if (!meta.path_mmproj.empty()) {
            preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj;
        }
    }
    meta.args = preset.to_args();
    // add back the binary path at the front
    meta.args.insert(meta.args.begin(), get_server_exec_path().string());
}
  206. //
  207. // server_models
  208. //
// Snapshot the router's argv/envp (children inherit both), replace argv[0]
// with the resolved server executable path, then discover all models.
server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) {
    for (int i = 0; i < argc; i++) {
        base_args.push_back(std::string(argv[i]));
    }
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    GGML_ASSERT(!base_args.empty());
    // set binary path
    try {
        base_args[0] = get_server_exec_path().string();
    } catch (const std::exception & e) {
        // non-fatal: spawning children may still work via the original argv[0]
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
    }
    load_models();
}
// Register a model under meta.name with a fresh, not-yet-spawned instance.
// Called during startup discovery only (no locking here); duplicate names
// are a configuration error.
void server_models::add_model(server_model_meta && meta) {
    if (mapping.find(meta.name) != mapping.end()) {
        throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
    }
    presets.render_args(meta); // populate meta.args
    std::string name = meta.name; // copy the key before meta is moved into the map
    mapping[name] = instance_t{
        /* subproc */ std::make_shared<subprocess_s>(),
        /* th */ std::thread(),
        /* meta */ std::move(meta)
    };
}
  242. static std::vector<local_model> list_custom_path_models(server_presets & presets) {
  243. // detect any custom-path models in presets
  244. std::vector<local_model> custom_models;
  245. for (auto & [model_name, preset] : presets.presets) {
  246. local_model model;
  247. model.name = model_name;
  248. std::vector<common_arg> to_erase;
  249. for (auto & [arg, value] : preset.options) {
  250. std::string env(arg.env ? arg.env : "");
  251. if (env == "LLAMA_ARG_MODEL") {
  252. model.path = value;
  253. to_erase.push_back(arg);
  254. }
  255. if (env == "LLAMA_ARG_MMPROJ") {
  256. model.path_mmproj = value;
  257. to_erase.push_back(arg);
  258. }
  259. }
  260. for (auto & arg : to_erase) {
  261. preset.options.erase(arg);
  262. }
  263. if (!model.name.empty() && !model.path.empty()) {
  264. custom_models.push_back(model);
  265. }
  266. }
  267. return custom_models;
  268. }
// TODO: allow refreshing cached model list
// Discover every model the router can serve, from 3 sources (in this order):
//   1. models in the cache
//   2. models under --models-dir (skipped when the name already exists)
//   3. custom-path models declared in preset files
// NOTE(review): source 3 does not check for name collisions like source 2
// does, so a preset whose name matches a cached model makes add_model throw
// at startup — confirm this is the intended behavior.
void server_models::load_models() {
    // loading models from 3 sources:
    // 1. cached models
    auto cached_models = common_list_cached_models();
    for (const auto & model : cached_models) {
        server_model_meta meta{
            /* preset */ presets.get_preset(model.to_string()),
            /* name */ model.to_string(),
            /* path */ model.manifest_path,
            /* path_mmproj */ "", // auto-detected when loading
            /* in_cache */ true,
            /* port */ 0,
            /* status */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */ 0,
            /* args */ std::vector<std::string>(),
            /* exit_code */ 0
        };
        add_model(std::move(meta));
    }
    // 2. local models specified via --models-dir
    if (!base_params.models_dir.empty()) {
        auto local_models = list_local_models(base_params.models_dir);
        for (const auto & model : local_models) {
            if (mapping.find(model.name) != mapping.end()) {
                // already exists in cached models, skip
                continue;
            }
            server_model_meta meta{
                /* preset */ presets.get_preset(model.name),
                /* name */ model.name,
                /* path */ model.path,
                /* path_mmproj */ model.path_mmproj,
                /* in_cache */ false,
                /* port */ 0,
                /* status */ SERVER_MODEL_STATUS_UNLOADED,
                /* last_used */ 0,
                /* args */ std::vector<std::string>(),
                /* exit_code */ 0
            };
            add_model(std::move(meta));
        }
    }
    // 3. custom-path models specified in presets
    auto custom_models = list_custom_path_models(presets);
    for (const auto & model : custom_models) {
        server_model_meta meta{
            /* preset */ presets.get_preset(model.name),
            /* name */ model.name,
            /* path */ model.path,
            /* path_mmproj */ model.path_mmproj,
            /* in_cache */ false,
            /* port */ 0,
            /* status */ SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */ 0,
            /* args */ std::vector<std::string>(),
            /* exit_code */ 0
        };
        add_model(std::move(meta));
    }
    // log available models
    SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
    for (const auto & [name, inst] : mapping) {
        SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? ' ' : '*', name.c_str());
    }
}
  335. void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
  336. std::lock_guard<std::mutex> lk(mutex);
  337. auto it = mapping.find(name);
  338. if (it != mapping.end()) {
  339. it->second.meta = meta;
  340. }
  341. cv.notify_all(); // notify wait_until_loaded
  342. }
  343. bool server_models::has_model(const std::string & name) {
  344. std::lock_guard<std::mutex> lk(mutex);
  345. return mapping.find(name) != mapping.end();
  346. }
  347. std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
  348. std::lock_guard<std::mutex> lk(mutex);
  349. auto it = mapping.find(name);
  350. if (it != mapping.end()) {
  351. return it->second.meta;
  352. }
  353. return std::nullopt;
  354. }
// Ask the OS for a currently-free TCP port by binding a socket to port 0 and
// reading back the assigned port, then closing the socket. Returns -1 on any
// failure. Note: the port is only *likely* free — another process could take
// it between this call and the child server's own bind.
static int get_free_port() {
#ifdef _WIN32
    WSADATA wsaData;
    if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
        return -1;
    }
    typedef SOCKET native_socket_t;
#define INVALID_SOCKET_VAL INVALID_SOCKET
#define CLOSE_SOCKET(s) closesocket(s)
#else
    typedef int native_socket_t;
#define INVALID_SOCKET_VAL -1
#define CLOSE_SOCKET(s) close(s)
#endif
    native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
    if (sock == INVALID_SOCKET_VAL) {
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    struct sockaddr_in serv_addr;
    std::memset(&serv_addr, 0, sizeof(serv_addr));
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
    serv_addr.sin_port = htons(0); // port 0 = let the OS pick an ephemeral port
    if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
#ifdef _WIN32
    int namelen = sizeof(serv_addr);
#else
    socklen_t namelen = sizeof(serv_addr);
#endif
    // read back the port the OS actually assigned
    if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
        CLOSE_SOCKET(sock);
#ifdef _WIN32
        WSACleanup();
#endif
        return -1;
    }
    int port = ntohs(serv_addr.sin_port);
    CLOSE_SOCKET(sock);
#ifdef _WIN32
    WSACleanup();
#endif
    return port;
}
  407. // helper to convert vector<string> to char **
  408. // pointers are only valid as long as the original vector is valid
  409. static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
  410. std::vector<char *> result;
  411. result.reserve(vec.size() + 1);
  412. for (const auto & s : vec) {
  413. result.push_back(const_cast<char*>(s.c_str()));
  414. }
  415. result.push_back(nullptr);
  416. return result;
  417. }
  418. std::vector<server_model_meta> server_models::get_all_meta() {
  419. std::lock_guard<std::mutex> lk(mutex);
  420. std::vector<server_model_meta> result;
  421. result.reserve(mapping.size());
  422. for (const auto & [name, inst] : mapping) {
  423. result.push_back(inst.meta);
  424. }
  425. return result;
  426. }
// Enforce --models-max: when the number of active instances has reached the
// limit, request shutdown of the least-recently-used one.
// Locks `mutex` internally only for the scan; unload() takes the lock again,
// so this must be called without holding it.
void server_models::unload_lru() {
    if (base_params.models_max <= 0) {
        return; // no limit
    }
    // remove one of the servers if we passed the models_max (least recently used - LRU)
    std::string lru_model_name = "";
    int64_t lru_last_used = ggml_time_ms(); // any active model was last used at or before "now"
    size_t count_active = 0;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (const auto & m : mapping) {
            if (m.second.meta.is_active()) {
                count_active++;
                if (m.second.meta.last_used < lru_last_used) {
                    lru_model_name = m.first;
                    lru_last_used = m.second.meta.last_used;
                }
            }
        }
    }
    // only evict when the limit is actually reached and a candidate was found
    if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
        SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
        unload(lru_model_name);
    }
}
// Spawn a child server process for model `name` and start a thread that pumps
// the child's combined stdout/stderr into the router log and reaps the process
// when it exits. Throws if the model is unknown, no free port can be obtained,
// or the spawn fails.
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // may first evict the least-recently-used instance to respect --models-max
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        presets.render_args(inst.meta); // update meta.args
        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env = base_env; // copy
        // tell the child how to reach the router for the ready notification
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        // argv/envp only borrow the strings; child_args/child_env outlive the call
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        // keep the child's stdin handle so we can later send the exit command
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update PID and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all();
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
  538. static void interrupt_subprocess(FILE * stdin_file) {
  539. // because subprocess.h does not provide a way to send SIGINT,
  540. // we will send a command to the child process to exit gracefully
  541. if (stdin_file) {
  542. fprintf(stdin_file, "%s\n", CMD_EXIT);
  543. fflush(stdin_file);
  544. }
  545. }
  546. void server_models::unload(const std::string & name) {
  547. std::lock_guard<std::mutex> lk(mutex);
  548. auto it = mapping.find(name);
  549. if (it != mapping.end()) {
  550. if (it->second.meta.is_active()) {
  551. SRV_INF("unloading model instance name=%s\n", name.c_str());
  552. interrupt_subprocess(it->second.stdin_file);
  553. // status change will be handled by the managing thread
  554. } else {
  555. SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
  556. }
  557. }
  558. }
// Request shutdown of every active instance and join all managing threads.
// The threads are moved out of the map while holding the lock but joined
// after releasing it: each managing thread takes `mutex` itself when the
// child exits, so joining under the lock would deadlock.
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // moving the thread to join list to avoid deadlock
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
  579. void server_models::update_status(const std::string & name, server_model_status status) {
  580. // for now, we only allow updating to LOADED status
  581. if (status != SERVER_MODEL_STATUS_LOADED) {
  582. throw std::runtime_error("invalid status value");
  583. }
  584. auto meta = get_meta(name);
  585. if (meta.has_value()) {
  586. meta->status = status;
  587. update_meta(name, meta.value());
  588. }
  589. }
  590. void server_models::wait_until_loaded(const std::string & name) {
  591. std::unique_lock<std::mutex> lk(mutex);
  592. cv.wait(lk, [this, &name]() {
  593. auto it = mapping.find(name);
  594. if (it != mapping.end()) {
  595. return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
  596. }
  597. return false;
  598. });
  599. }
// Make sure the named model is loaded, blocking until it is ready.
// Returns false when it was already loaded, true when a load was performed
// (or was already in progress and completed). Throws if the model is unknown
// or ends up in a failed state.
bool server_models::ensure_model_loaded(const std::string & name) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status == SERVER_MODEL_STATUS_LOADED) {
        return false; // already loaded
    }
    if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
        SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
        load(name);
    }
    // at this point the model is LOADING (started by us or by another thread)
    SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
    wait_until_loaded(name);
    // check final status
    meta = get_meta(name);
    if (!meta.has_value() || meta->is_failed()) {
        throw std::runtime_error("model name=" + name + " failed to load");
    }
    return true;
}
// Forward an HTTP request to the child server instance hosting `name`.
// Throws std::runtime_error for unknown models and std::invalid_argument when
// the model is not currently loaded. When `update_last_used` is set, the
// model's LRU timestamp is refreshed (done for POST/inference requests only).
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    // children always listen on the loopback address, on the port chosen at load time
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
// Runs inside a child server process. First notifies the router over HTTP that
// this instance is ready (exiting if the router is unreachable), then returns
// a thread that watches stdin:
//  - receiving the exit command triggers a graceful shutdown via shutdown_handler
//  - EOF means the router died or was killed, so the child force-exits
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // TODO @ngxson : use HTTP client from libcommon
    httplib::Client cli(base_params.hostname, router_port);
    cli.set_connection_timeout(0, 200000); // 200 milliseconds
    httplib::Request req;
    req.method = "POST";
    req.path = "/models/status";
    req.set_header("Content-Type", "application/json");
    if (!base_params.api_keys.empty()) {
        // reuse the router's first API key to authenticate the callback
        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
    }
    json body;
    body["model"] = name;
    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
    req.body = body.dump();
    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
    auto result = cli.send(std::move(req));
    if (result.error() != httplib::Error::Success) {
        auto err_str = httplib::to_string(result.error());
        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
        exit(1); // force exit: without the router this instance is unreachable
    }
    // setup thread for monitoring stdin
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected: the router server unexpectedly exited or was killed
                eof = true;
                break;
            }
            if (line.find(CMD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
  691. //
  692. // server_models_routes
  693. //
  694. static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
  695. res->status = 200;
  696. res->data = safe_json_to_str(response_data);
  697. }
  698. static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
  699. res->status = json_value(error_data, "code", 500);
  700. res->data = safe_json_to_str({{ "error", error_data }});
  701. }
  702. static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
  703. if (name.empty()) {
  704. res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
  705. return false;
  706. }
  707. auto meta = models.get_meta(name);
  708. if (!meta.has_value()) {
  709. res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
  710. return false;
  711. }
  712. if (models_autoload) {
  713. models.ensure_model_loaded(name);
  714. } else {
  715. if (meta->status != SERVER_MODEL_STATUS_LOADED) {
  716. res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
  717. return false;
  718. }
  719. }
  720. return true;
  721. }
  722. static bool is_autoload(const common_params & params, const server_http_req & req) {
  723. std::string autoload = req.get_param("autoload");
  724. if (autoload.empty()) {
  725. return params.models_autoload;
  726. } else {
  727. return autoload == "true" || autoload == "1";
  728. }
  729. }
// Register all HTTP route handlers owned by the router process. Each handler
// is stored as a std::function member and captures only `this`, so this
// object must outlive the HTTP server loop.
void server_models_routes::init_routes() {
    // props handler: with no "model" param, answer on behalf of the router
    // itself with a fixed payload (keeps the web UI working); otherwise
    // forward the request to the selected child instance via proxy_get
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
            });
            return res;
        }
        return proxy_get(req);
    };
    // generic GET proxy: target model comes from the "model" query param;
    // does not refresh the model's last-usage bookkeeping (final arg false)
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, false);
    };
    // generic POST proxy: target model is read from the JSON request body
    // NOTE(review): json::parse throws on malformed bodies — presumably an
    // upstream handler turns that into an HTTP error; confirm
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };
    // explicit model load request: rejects unknown or already-loaded models
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            // NOTE(review): uses ERROR_TYPE_NOT_FOUND here while the unload
            // handler below uses ERROR_TYPE_INVALID_REQUEST for the same
            // condition — consider unifying
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };
    // used by child process to notify the router about status change
    // TODO @ngxson : maybe implement authentication for this endpoint in the future
    this->post_router_models_status = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        std::string value = json_value(body, "value", std::string());
        models.update_status(model, server_model_status_from_string(value));
        res_ok(res, {{"success", true}});
        return res;
    };
    // list every known model with its status; response mirrors the
    // OAI-compatible "list of models" shape
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        std::time_t t = std::time(0); // one timestamp shared by all entries
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (!meta.preset.name.empty()) {
                status["preset"] = meta.preset.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"in_cache", meta.in_cache},
                {"path", meta.path},
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };
    // explicit model unload request: rejects unknown or not-loaded models
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
  853. //
  854. // server_http_proxy
  855. //
  856. // simple implementation of a pipe
  857. // used for streaming data between threads
  858. template<typename T>
  859. struct pipe_t {
  860. std::mutex mutex;
  861. std::condition_variable cv;
  862. std::queue<T> queue;
  863. std::atomic<bool> writer_closed{false};
  864. std::atomic<bool> reader_closed{false};
  865. void close_write() {
  866. writer_closed.store(true, std::memory_order_relaxed);
  867. cv.notify_all();
  868. }
  869. void close_read() {
  870. reader_closed.store(true, std::memory_order_relaxed);
  871. cv.notify_all();
  872. }
  873. bool read(T & output, const std::function<bool()> & should_stop) {
  874. std::unique_lock<std::mutex> lk(mutex);
  875. constexpr auto poll_interval = std::chrono::milliseconds(500);
  876. while (true) {
  877. if (!queue.empty()) {
  878. output = std::move(queue.front());
  879. queue.pop();
  880. return true;
  881. }
  882. if (writer_closed.load()) {
  883. return false; // clean EOF
  884. }
  885. if (should_stop()) {
  886. close_read(); // signal broken pipe to writer
  887. return false; // cancelled / reader no longer alive
  888. }
  889. cv.wait_for(lk, poll_interval);
  890. }
  891. }
  892. bool write(T && data) {
  893. std::lock_guard<std::mutex> lk(mutex);
  894. if (reader_closed.load()) {
  895. return false; // broken pipe
  896. }
  897. queue.push(std::move(data));
  898. cv.notify_one();
  899. return true;
  900. }
  901. };
  902. static std::string to_lower_copy(const std::string & value) {
  903. std::string lowered(value.size(), '\0');
  904. std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
  905. return lowered;
  906. }
  907. static bool should_strip_proxy_header(const std::string & header_name) {
  908. // Headers that get duplicated when router forwards child responses
  909. if (header_name == "server" ||
  910. header_name == "transfer-encoding" ||
  911. header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
  912. header_name == "keep-alive") {
  913. return true;
  914. }
  915. // Router injects CORS, child also sends them: duplicate
  916. if (header_name.rfind("access-control-", 0) == 0) {
  917. return true;
  918. }
  919. return false;
  920. }
// Forward one HTTP request to a backend server at host:port and stream the
// response back through a pipe_t. The request runs on a detached thread; the
// constructor blocks only until the first pipe message (the response headers)
// arrives, or until `should_stop` reports the client has gone away. The body
// is then consumed chunk-by-chunk via `this->next`.
server_http_proxy::server_http_proxy(
    const std::string & method,
    const std::string & host,
    int port,
    const std::string & path,
    const std::map<std::string, std::string> & headers,
    const std::string & body,
    const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    // closing both ends unblocks whichever side is still waiting
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    // runs once when response headers arrive; forwards status, content-type
    // and the filtered header set as the FIRST pipe message
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            if (should_strip_proxy_header(lowered)) {
                continue; // router emits its own copy of this header
            }
            if (lowered == "content-type") {
                msg.content_type = value; // carried separately from other headers
                continue;
            }
            msg.headers[key] = value;
        }
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        // (aggregate order appears to be {headers, status, data, content_type} — see msg_t)
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // synthesize a 500 header + error body so the reader is not left hanging
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detaching is safe: the thread touches only `cli` and `pipe`, both kept
    // alive by shared_ptr captures (never `this`)
    this->thread.detach();
    // wait for the first chunk (headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}