server-models.cpp 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109
  1. #include "server-common.h"
  2. #include "server-models.h"
  3. #include "preset.h"
  4. #include "download.h"
  5. #include <cpp-httplib/httplib.h> // TODO: remove this once we use HTTP client from download.h
  6. #include <sheredom/subprocess.h>
  7. #include <functional>
  8. #include <algorithm>
  9. #include <thread>
  10. #include <mutex>
  11. #include <condition_variable>
  12. #include <cstring>
  13. #include <atomic>
  14. #include <chrono>
  15. #include <queue>
  16. #ifdef _WIN32
  17. #include <winsock2.h>
  18. #else
  19. #include <sys/socket.h>
  20. #include <netinet/in.h>
  21. #include <arpa/inet.h>
  22. #include <unistd.h>
  23. #endif
  24. #if defined(__APPLE__) && defined(__MACH__)
  25. // macOS: use _NSGetExecutablePath to get the executable path
  26. #include <mach-o/dyld.h>
  27. #include <limits.h>
  28. #endif
  29. #define CMD_EXIT "exit"
  30. // address for child process, this is needed because router may run on 0.0.0.0
  31. // ref: https://github.com/ggml-org/llama.cpp/issues/17862
  32. #define CHILD_ADDR "127.0.0.1"
  33. static std::filesystem::path get_server_exec_path() {
  34. #if defined(_WIN32)
  35. wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths
  36. DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf));
  37. if (len == 0 || len >= _countof(buf)) {
  38. throw std::runtime_error("GetModuleFileNameW failed or path too long");
  39. }
  40. return std::filesystem::path(buf);
  41. #elif defined(__APPLE__) && defined(__MACH__)
  42. char small_path[PATH_MAX];
  43. uint32_t size = sizeof(small_path);
  44. if (_NSGetExecutablePath(small_path, &size) == 0) {
  45. // resolve any symlinks to get absolute path
  46. try {
  47. return std::filesystem::canonical(std::filesystem::path(small_path));
  48. } catch (...) {
  49. return std::filesystem::path(small_path);
  50. }
  51. } else {
  52. // buffer was too small, allocate required size and call again
  53. std::vector<char> buf(size);
  54. if (_NSGetExecutablePath(buf.data(), &size) == 0) {
  55. try {
  56. return std::filesystem::canonical(std::filesystem::path(buf.data()));
  57. } catch (...) {
  58. return std::filesystem::path(buf.data());
  59. }
  60. }
  61. throw std::runtime_error("_NSGetExecutablePath failed after buffer resize");
  62. }
  63. #else
  64. char path[FILENAME_MAX];
  65. ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX);
  66. if (count <= 0) {
  67. throw std::runtime_error("failed to resolve /proc/self/exe");
  68. }
  69. return std::filesystem::path(std::string(path, count));
  70. #endif
  71. }
// A model discovered on the local filesystem (as opposed to the model cache).
struct local_model {
    std::string name;        // alias, derived from the file or directory name
    std::string path;        // path to the .gguf file (or the first shard)
    std::string path_mmproj; // path to the mmproj .gguf; empty if none
};
  77. static std::vector<local_model> list_local_models(const std::string & dir) {
  78. if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
  79. throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str()));
  80. }
  81. std::vector<local_model> models;
  82. auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) {
  83. auto files = fs_list(subdir_path, false);
  84. common_file_info model_file;
  85. common_file_info first_shard_file;
  86. common_file_info mmproj_file;
  87. for (const auto & file : files) {
  88. if (string_ends_with(file.name, ".gguf")) {
  89. if (file.name.find("mmproj") != std::string::npos) {
  90. mmproj_file = file;
  91. } else if (file.name.find("-00001-of-") != std::string::npos) {
  92. first_shard_file = file;
  93. } else {
  94. model_file = file;
  95. }
  96. }
  97. }
  98. // single file model
  99. local_model model{
  100. /* name */ name,
  101. /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path,
  102. /* path_mmproj */ mmproj_file.path // can be empty
  103. };
  104. if (!model.path.empty()) {
  105. models.push_back(model);
  106. }
  107. };
  108. auto files = fs_list(dir, true);
  109. for (const auto & file : files) {
  110. if (file.is_dir) {
  111. scan_subdir(file.path, file.name);
  112. } else if (string_ends_with(file.name, ".gguf")) {
  113. // single file model
  114. std::string name = file.name;
  115. string_replace_all(name, ".gguf", "");
  116. local_model model{
  117. /* name */ name,
  118. /* path */ file.path,
  119. /* path_mmproj */ ""
  120. };
  121. models.push_back(model);
  122. }
  123. }
  124. return models;
  125. }
  126. //
  127. // server_presets
  128. //
// Build the preset registry for the router:
//  1) load presets from `presets_path` (if provided),
//  2) record which args are reserved for the router ("control args"),
//  3) parse the router's own argv into base_args, minus the reserved ones.
server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path)
    : ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) {
    if (!presets_path.empty()) {
        presets = common_presets_load(presets_path, ctx_params);
        SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str());
    }
    // populate reserved args (will be appended by the router)
    for (auto & opt : ctx_params.options) {
        if (opt.env == nullptr) {
            continue; // options without an env var cannot be matched here
        }
        std::string env = opt.env;
        if (env == "LLAMA_ARG_PORT" ||
            env == "LLAMA_ARG_HOST" ||
            env == "LLAMA_ARG_ALIAS" ||
            env == "LLAMA_ARG_API_KEY" ||
            env == "LLAMA_ARG_MODELS_DIR" ||
            env == "LLAMA_ARG_MODELS_MAX" ||
            env == "LLAMA_ARG_MODELS_PRESET" ||
            env == "LLAMA_ARG_MODEL" ||
            env == "LLAMA_ARG_MMPROJ" ||
            env == "LLAMA_ARG_HF_REPO" ||
            env == "LLAMA_ARG_NO_MODELS_AUTOLOAD") {
            control_args[env] = opt;
        }
    }
    // read base args from router's argv
    common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args);
    // remove any router-controlled args from base_args so that the values set
    // in render_args() are the only ones a child instance sees
    for (const auto & cargs : control_args) {
        auto it = base_args.find(cargs.second);
        if (it != base_args.end()) {
            base_args.erase(it);
        }
    }
}
  165. common_preset server_presets::get_preset(const std::string & name) {
  166. auto it = presets.find(name);
  167. if (it != presets.end()) {
  168. return it->second;
  169. }
  170. return common_preset();
  171. }
// Build the final argv for a model instance by layering three sources:
//  1. model-specific args from the preset (reserved args stripped),
//  2. base args inherited from the router's own command line,
//  3. router-controlled values (host/port/alias/model paths).
// The result is written to meta.args with the server binary path prepended.
void server_presets::render_args(server_model_meta & meta) {
    common_preset preset = meta.preset; // copy
    // merging 3 kinds of args:
    // 1. model-specific args (from preset)
    // force removing control args if any
    for (auto & cargs : control_args) {
        if (preset.options.find(cargs.second) != preset.options.end()) {
            SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]);
            preset.options.erase(cargs.second);
        }
    }
    // 2. base args (from router)
    // inherit from base args
    for (const auto & [arg, value] : base_args) {
        preset.options[arg] = value;
    }
    // 3. control args (from router)
    // set control values
    preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR;
    preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port);
    preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name;
    if (meta.in_cache) {
        // cached models are addressed by their repo name
        preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name;
    } else {
        preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path;
        if (!meta.path_mmproj.empty()) {
            preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj;
        }
    }
    meta.args = preset.to_args();
    // add back the binary path at the front
    meta.args.insert(meta.args.begin(), get_server_exec_path().string());
}
  205. //
  206. // server_models
  207. //
// Capture the router's argv/envp (reused when spawning child servers),
// fix up argv[0] to the absolute binary path, and discover all models.
server_models::server_models(
        const common_params & params,
        int argc,
        char ** argv,
        char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) {
    for (int i = 0; i < argc; i++) {
        base_args.push_back(std::string(argv[i]));
    }
    // envp is a null-terminated array of "KEY=VALUE" strings
    for (char ** env = envp; *env != nullptr; env++) {
        base_env.push_back(std::string(*env));
    }
    GGML_ASSERT(!base_args.empty());
    // set binary path
    try {
        base_args[0] = get_server_exec_path().string();
    } catch (const std::exception & e) {
        LOG_WRN("failed to get server executable path: %s\n", e.what());
        LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str());
    }
    load_models();
}
  229. void server_models::add_model(server_model_meta && meta) {
  230. if (mapping.find(meta.name) != mapping.end()) {
  231. throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str()));
  232. }
  233. presets.render_args(meta); // populate meta.args
  234. std::string name = meta.name;
  235. mapping[name] = instance_t{
  236. /* subproc */ std::make_shared<subprocess_s>(),
  237. /* th */ std::thread(),
  238. /* meta */ std::move(meta)
  239. };
  240. }
  241. static std::vector<local_model> list_custom_path_models(server_presets & presets) {
  242. // detect any custom-path models in presets
  243. std::vector<local_model> custom_models;
  244. for (auto & [model_name, preset] : presets.presets) {
  245. local_model model;
  246. model.name = model_name;
  247. std::vector<common_arg> to_erase;
  248. for (auto & [arg, value] : preset.options) {
  249. std::string env(arg.env ? arg.env : "");
  250. if (env == "LLAMA_ARG_MODEL") {
  251. model.path = value;
  252. to_erase.push_back(arg);
  253. }
  254. if (env == "LLAMA_ARG_MMPROJ") {
  255. model.path_mmproj = value;
  256. to_erase.push_back(arg);
  257. }
  258. }
  259. for (auto & arg : to_erase) {
  260. preset.options.erase(arg);
  261. }
  262. if (!model.name.empty() && !model.path.empty()) {
  263. custom_models.push_back(model);
  264. }
  265. }
  266. return custom_models;
  267. }
// TODO: allow refreshing cached model list
// Populate `mapping` with every model the router can serve, from 3 sources:
// the model cache, --models-dir, and custom paths declared in presets.
void server_models::load_models() {
    // loading models from 3 sources:
    // 1. cached models
    auto cached_models = common_list_cached_models();
    for (const auto & model : cached_models) {
        server_model_meta meta{
            /* preset */      presets.get_preset(model.to_string()),
            /* name */        model.to_string(),
            /* path */        model.manifest_path,
            /* path_mmproj */ "", // auto-detected when loading
            /* in_cache */    true,
            /* port */        0,
            /* status */      SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */   0,
            /* args */        std::vector<std::string>(),
            /* exit_code */   0
        };
        add_model(std::move(meta));
    }
    // 2. local models specified via --models-dir
    if (!base_params.models_dir.empty()) {
        auto local_models = list_local_models(base_params.models_dir);
        for (const auto & model : local_models) {
            if (mapping.find(model.name) != mapping.end()) {
                // already exists in cached models, skip
                continue;
            }
            server_model_meta meta{
                /* preset */      presets.get_preset(model.name),
                /* name */        model.name,
                /* path */        model.path,
                /* path_mmproj */ model.path_mmproj,
                /* in_cache */    false,
                /* port */        0,
                /* status */      SERVER_MODEL_STATUS_UNLOADED,
                /* last_used */   0,
                /* args */        std::vector<std::string>(),
                /* exit_code */   0
            };
            add_model(std::move(meta));
        }
    }
    // 3. custom-path models specified in presets
    // note: add_model() throws if a name collides with sources 1 or 2
    auto custom_models = list_custom_path_models(presets);
    for (const auto & model : custom_models) {
        server_model_meta meta{
            /* preset */      presets.get_preset(model.name),
            /* name */        model.name,
            /* path */        model.path,
            /* path_mmproj */ model.path_mmproj,
            /* in_cache */    false,
            /* port */        0,
            /* status */      SERVER_MODEL_STATUS_UNLOADED,
            /* last_used */   0,
            /* args */        std::vector<std::string>(),
            /* exit_code */   0
        };
        add_model(std::move(meta));
    }
    // log available models
    SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size());
    for (const auto & [name, inst] : mapping) {
        SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? ' ' : '*', name.c_str());
    }
}
  334. void server_models::update_meta(const std::string & name, const server_model_meta & meta) {
  335. std::lock_guard<std::mutex> lk(mutex);
  336. auto it = mapping.find(name);
  337. if (it != mapping.end()) {
  338. it->second.meta = meta;
  339. }
  340. cv.notify_all(); // notify wait_until_loaded
  341. }
  342. bool server_models::has_model(const std::string & name) {
  343. std::lock_guard<std::mutex> lk(mutex);
  344. return mapping.find(name) != mapping.end();
  345. }
  346. std::optional<server_model_meta> server_models::get_meta(const std::string & name) {
  347. std::lock_guard<std::mutex> lk(mutex);
  348. auto it = mapping.find(name);
  349. if (it != mapping.end()) {
  350. return it->second.meta;
  351. }
  352. return std::nullopt;
  353. }
  354. static int get_free_port() {
  355. #ifdef _WIN32
  356. WSADATA wsaData;
  357. if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) {
  358. return -1;
  359. }
  360. typedef SOCKET native_socket_t;
  361. #define INVALID_SOCKET_VAL INVALID_SOCKET
  362. #define CLOSE_SOCKET(s) closesocket(s)
  363. #else
  364. typedef int native_socket_t;
  365. #define INVALID_SOCKET_VAL -1
  366. #define CLOSE_SOCKET(s) close(s)
  367. #endif
  368. native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0);
  369. if (sock == INVALID_SOCKET_VAL) {
  370. #ifdef _WIN32
  371. WSACleanup();
  372. #endif
  373. return -1;
  374. }
  375. struct sockaddr_in serv_addr;
  376. std::memset(&serv_addr, 0, sizeof(serv_addr));
  377. serv_addr.sin_family = AF_INET;
  378. serv_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  379. serv_addr.sin_port = htons(0);
  380. if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) {
  381. CLOSE_SOCKET(sock);
  382. #ifdef _WIN32
  383. WSACleanup();
  384. #endif
  385. return -1;
  386. }
  387. #ifdef _WIN32
  388. int namelen = sizeof(serv_addr);
  389. #else
  390. socklen_t namelen = sizeof(serv_addr);
  391. #endif
  392. if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) {
  393. CLOSE_SOCKET(sock);
  394. #ifdef _WIN32
  395. WSACleanup();
  396. #endif
  397. return -1;
  398. }
  399. int port = ntohs(serv_addr.sin_port);
  400. CLOSE_SOCKET(sock);
  401. #ifdef _WIN32
  402. WSACleanup();
  403. #endif
  404. return port;
  405. }
  406. // helper to convert vector<string> to char **
  407. // pointers are only valid as long as the original vector is valid
  408. static std::vector<char *> to_char_ptr_array(const std::vector<std::string> & vec) {
  409. std::vector<char *> result;
  410. result.reserve(vec.size() + 1);
  411. for (const auto & s : vec) {
  412. result.push_back(const_cast<char*>(s.c_str()));
  413. }
  414. result.push_back(nullptr);
  415. return result;
  416. }
  417. std::vector<server_model_meta> server_models::get_all_meta() {
  418. std::lock_guard<std::mutex> lk(mutex);
  419. std::vector<server_model_meta> result;
  420. result.reserve(mapping.size());
  421. for (const auto & [name, inst] : mapping) {
  422. result.push_back(inst.meta);
  423. }
  424. return result;
  425. }
  426. void server_models::unload_lru() {
  427. if (base_params.models_max <= 0) {
  428. return; // no limit
  429. }
  430. // remove one of the servers if we passed the models_max (least recently used - LRU)
  431. std::string lru_model_name = "";
  432. int64_t lru_last_used = ggml_time_ms();
  433. size_t count_active = 0;
  434. {
  435. std::lock_guard<std::mutex> lk(mutex);
  436. for (const auto & m : mapping) {
  437. if (m.second.meta.is_active()) {
  438. count_active++;
  439. if (m.second.meta.last_used < lru_last_used) {
  440. lru_model_name = m.first;
  441. lru_last_used = m.second.meta.last_used;
  442. }
  443. }
  444. }
  445. }
  446. if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) {
  447. SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str());
  448. unload(lru_model_name);
  449. }
  450. }
// Spawn a child server process for model `name` and start a manager thread
// that forwards its output and reaps it on exit.
// Throws when the model is unknown or no free port can be obtained.
void server_models::load(const std::string & name) {
    if (!has_model(name)) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    // make room for the new instance if we are at the models_max limit
    unload_lru();
    std::lock_guard<std::mutex> lk(mutex);
    auto meta = mapping[name].meta;
    if (meta.status != SERVER_MODEL_STATUS_UNLOADED) {
        // already loading or loaded; nothing to do
        SRV_INF("model %s is not ready\n", name.c_str());
        return;
    }
    // prepare new instance info
    instance_t inst;
    inst.meta = meta;
    inst.meta.port = get_free_port();
    inst.meta.status = SERVER_MODEL_STATUS_LOADING;
    inst.meta.last_used = ggml_time_ms();
    if (inst.meta.port <= 0) {
        throw std::runtime_error("failed to get a port number");
    }
    inst.subproc = std::make_shared<subprocess_s>();
    {
        SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port);
        presets.render_args(inst.meta); // update meta.args
        std::vector<std::string> child_args = inst.meta.args; // copy
        std::vector<std::string> child_env = base_env; // copy
        // tell the child the router's port so it can report back its status
        child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port));
        SRV_INF("%s", "spawning server instance with args:\n");
        for (const auto & arg : child_args) {
            SRV_INF(" %s\n", arg.c_str());
        }
        inst.meta.args = child_args; // save for debugging
        std::vector<char *> argv = to_char_ptr_array(child_args);
        std::vector<char *> envp = to_char_ptr_array(child_env);
        int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr;
        int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get());
        if (result != 0) {
            throw std::runtime_error("failed to spawn server instance");
        }
        // keep the child's stdin so unload() can send it the exit command
        inst.stdin_file = subprocess_stdin(inst.subproc.get());
    }
    // start a thread to manage the child process
    // captured variables are guaranteed to be destroyed only after the thread is joined
    inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() {
        // read stdout/stderr and forward to main server log
        FILE * p_stdout_stderr = subprocess_stdout(child_proc.get());
        if (p_stdout_stderr) {
            char buffer[4096];
            // fgets returns nullptr on EOF, i.e. when the child closes its output
            while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) {
                LOG("[%5d] %s", port, buffer);
            }
        } else {
            SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str());
        }
        // we reach here when the child process exits
        int exit_code = 0;
        subprocess_join(child_proc.get(), &exit_code);
        subprocess_destroy(child_proc.get());
        // update PID and status
        {
            std::lock_guard<std::mutex> lk(mutex);
            auto it = mapping.find(name);
            if (it != mapping.end()) {
                auto & meta = it->second.meta;
                meta.exit_code = exit_code;
                meta.status = SERVER_MODEL_STATUS_UNLOADED;
            }
            cv.notify_all(); // wake up wait_until_loaded()
        }
        SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code);
    });
    // clean up old process/thread if exists
    {
        auto & old_instance = mapping[name];
        // old process should have exited already, but just in case, we clean it up here
        if (subprocess_alive(old_instance.subproc.get())) {
            SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str());
            subprocess_terminate(old_instance.subproc.get()); // force kill
        }
        if (old_instance.th.joinable()) {
            old_instance.th.join();
        }
    }
    mapping[name] = std::move(inst);
    cv.notify_all();
}
  537. static void interrupt_subprocess(FILE * stdin_file) {
  538. // because subprocess.h does not provide a way to send SIGINT,
  539. // we will send a command to the child process to exit gracefully
  540. if (stdin_file) {
  541. fprintf(stdin_file, "%s\n", CMD_EXIT);
  542. fflush(stdin_file);
  543. }
  544. }
  545. void server_models::unload(const std::string & name) {
  546. std::lock_guard<std::mutex> lk(mutex);
  547. auto it = mapping.find(name);
  548. if (it != mapping.end()) {
  549. if (it->second.meta.is_active()) {
  550. SRV_INF("unloading model instance name=%s\n", name.c_str());
  551. interrupt_subprocess(it->second.stdin_file);
  552. // status change will be handled by the managing thread
  553. } else {
  554. SRV_WRN("model instance name=%s is not loaded\n", name.c_str());
  555. }
  556. }
  557. }
// Request shutdown of all active instances, then join all manager threads.
// Joining happens outside the lock because the manager threads take the same
// mutex to update the status on exit — joining under the lock would deadlock.
void server_models::unload_all() {
    std::vector<std::thread> to_join;
    {
        std::lock_guard<std::mutex> lk(mutex);
        for (auto & [name, inst] : mapping) {
            if (inst.meta.is_active()) {
                SRV_INF("unloading model instance name=%s\n", name.c_str());
                interrupt_subprocess(inst.stdin_file);
                // status change will be handled by the managing thread
            }
            // moving the thread to join list to avoid deadlock
            to_join.push_back(std::move(inst.th));
        }
    }
    for (auto & th : to_join) {
        if (th.joinable()) {
            th.join();
        }
    }
}
  578. void server_models::update_status(const std::string & name, server_model_status status) {
  579. // for now, we only allow updating to LOADED status
  580. if (status != SERVER_MODEL_STATUS_LOADED) {
  581. throw std::runtime_error("invalid status value");
  582. }
  583. auto meta = get_meta(name);
  584. if (meta.has_value()) {
  585. meta->status = status;
  586. update_meta(name, meta.value());
  587. }
  588. }
// Block until model `name` leaves the LOADING state (either it became LOADED
// or it went back to UNLOADED after the child exited).
// NOTE(review): if `name` is not present in `mapping`, the predicate never
// becomes true and this waits indefinitely — callers appear to validate the
// name first; confirm before relying on this elsewhere.
void server_models::wait_until_loaded(const std::string & name) {
    std::unique_lock<std::mutex> lk(mutex);
    cv.wait(lk, [this, &name]() {
        auto it = mapping.find(name);
        if (it != mapping.end()) {
            return it->second.meta.status != SERVER_MODEL_STATUS_LOADING;
        }
        return false; // unknown name: keep waiting
    });
}
  599. bool server_models::ensure_model_loaded(const std::string & name) {
  600. auto meta = get_meta(name);
  601. if (!meta.has_value()) {
  602. throw std::runtime_error("model name=" + name + " is not found");
  603. }
  604. if (meta->status == SERVER_MODEL_STATUS_LOADED) {
  605. return false; // already loaded
  606. }
  607. if (meta->status == SERVER_MODEL_STATUS_UNLOADED) {
  608. SRV_INF("model name=%s is not loaded, loading...\n", name.c_str());
  609. load(name);
  610. }
  611. SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str());
  612. wait_until_loaded(name);
  613. // check final status
  614. meta = get_meta(name);
  615. if (!meta.has_value() || meta->is_failed()) {
  616. throw std::runtime_error("model name=" + name + " failed to load");
  617. }
  618. return true;
  619. }
// Forward an HTTP request to the child instance serving `name`.
// Throws std::runtime_error when the model is unknown and
// std::invalid_argument when it is not currently loaded (distinct types so
// callers can map them to different responses).
server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) {
    auto meta = get_meta(name);
    if (!meta.has_value()) {
        throw std::runtime_error("model name=" + name + " is not found");
    }
    if (meta->status != SERVER_MODEL_STATUS_LOADED) {
        throw std::invalid_argument("model name=" + name + " is not loaded");
    }
    if (update_last_used) {
        // refresh the LRU timestamp so this model is not the next eviction
        std::unique_lock<std::mutex> lk(mutex);
        mapping[name].meta.last_used = ggml_time_ms();
    }
    SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port);
    auto proxy = std::make_unique<server_http_proxy>(
        method,
        CHILD_ADDR,
        meta->port,
        req.path,
        req.headers,
        req.body,
        req.should_stop);
    return proxy;
}
// Runs inside a child server process: notify the router that this instance
// is ready, then return a thread that watches stdin for the exit command.
// Exits the process when the router cannot be reached, or when stdin hits
// EOF (which means the router died or was killed).
std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler) {
    // send a notification to the router server that a model instance is ready
    // TODO @ngxson : use HTTP client from libcommon
    httplib::Client cli(base_params.hostname, router_port);
    cli.set_connection_timeout(0, 200000); // 200 milliseconds
    httplib::Request req;
    req.method = "POST";
    req.path = "/models/status";
    req.set_header("Content-Type", "application/json");
    if (!base_params.api_keys.empty()) {
        req.set_header("Authorization", "Bearer " + base_params.api_keys[0]);
    }
    json body;
    body["model"] = name;
    body["value"] = server_model_status_to_string(SERVER_MODEL_STATUS_LOADED);
    req.body = body.dump();
    SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str());
    auto result = cli.send(std::move(req));
    if (result.error() != httplib::Error::Success) {
        auto err_str = httplib::to_string(result.error());
        SRV_ERR("failed to notify router server: %s\n", err_str.c_str());
        exit(1); // force exit
    }
    // setup thread for monitoring stdin
    // shutdown_handler is captured by copy so it remains valid in the thread
    return std::thread([shutdown_handler]() {
        // wait for EOF on stdin
        SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n");
        bool eof = false;
        while (true) {
            std::string line;
            if (!std::getline(std::cin, line)) {
                // EOF detected, that means the router server is unexpectedly exit or killed
                eof = true;
                break;
            }
            if (line.find(CMD_EXIT) != std::string::npos) {
                SRV_INF("%s", "exit command received, exiting...\n");
                shutdown_handler(0);
                break;
            }
        }
        if (eof) {
            SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n");
            exit(1);
        }
    });
}
  690. //
  691. // server_models_routes
  692. //
  693. static void res_ok(std::unique_ptr<server_http_res> & res, const json & response_data) {
  694. res->status = 200;
  695. res->data = safe_json_to_str(response_data);
  696. }
  697. static void res_err(std::unique_ptr<server_http_res> & res, const json & error_data) {
  698. res->status = json_value(error_data, "code", 500);
  699. res->data = safe_json_to_str({{ "error", error_data }});
  700. }
  701. static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr<server_http_res> & res) {
  702. if (name.empty()) {
  703. res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST));
  704. return false;
  705. }
  706. auto meta = models.get_meta(name);
  707. if (!meta.has_value()) {
  708. res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST));
  709. return false;
  710. }
  711. if (models_autoload) {
  712. models.ensure_model_loaded(name);
  713. } else {
  714. if (meta->status != SERVER_MODEL_STATUS_LOADED) {
  715. res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
  716. return false;
  717. }
  718. }
  719. return true;
  720. }
  721. static bool is_autoload(const common_params & params, const server_http_req & req) {
  722. std::string autoload = req.get_param("autoload");
  723. if (autoload.empty()) {
  724. return params.models_autoload;
  725. } else {
  726. return autoload == "true" || autoload == "1";
  727. }
  728. }
// Wire up the router-process HTTP handlers. Each handler is stored as a
// std::function member; all lambdas capture `this` so they can reach the
// shared `models` registry and `params`.
void server_models_routes::init_routes() {
    // GET /props: with no "model" param, answer for the router itself using a
    // dummy payload (keeps the web UI working); otherwise proxy to the child.
    this->get_router_props = [this](const server_http_req & req) {
        std::string name = req.get_param("model");
        if (name.empty()) {
            // main instance
            auto res = std::make_unique<server_http_res>();
            res_ok(res, {
                // TODO: add support for this on web UI
                {"role", "router"},
                {"max_instances", 4}, // dummy value for testing
                // this is a dummy response to make sure webui doesn't break
                {"model_alias", "llama-server"},
                {"model_path", "none"},
                {"default_generation_settings", {
                    {"params", json{}},
                    {"n_ctx", 0},
                }},
            });
            return res;
        }
        return proxy_get(req);
    };
    // Generic GET proxy: model name comes from the query string; validation
    // may trigger an autoload depending on request/server settings.
    this->proxy_get = [this](const server_http_req & req) {
        std::string method = "GET";
        std::string name = req.get_param("model");
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        // GET requests do not bump the model's last-usage timestamp
        return models.proxy_request(req, method, name, false);
    };
    // Generic POST proxy: model name comes from the JSON body.
    // NOTE(review): json::parse(req.body) throws on malformed input — presumably
    // the HTTP layer converts that into an error response; confirm upstream.
    this->proxy_post = [this](const server_http_req & req) {
        std::string method = "POST";
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        bool autoload = is_autoload(params, req);
        auto error_res = std::make_unique<server_http_res>();
        if (!router_validate_model(name, models, autoload, error_res)) {
            return error_res;
        }
        return models.proxy_request(req, method, name, true); // update last usage for POST request only
    };
    // Explicitly load a model by name (no-op rejection if already loaded).
    this->post_router_models_load = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND));
            return res;
        }
        if (model->status == SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.load(name);
        res_ok(res, {{"success", true}});
        return res;
    };
    // used by child process to notify the router about status change
    // TODO @ngxson : maybe implement authentication for this endpoint in the future
    this->post_router_models_status = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string model = json_value(body, "model", std::string());
        std::string value = json_value(body, "value", std::string());
        models.update_status(model, server_model_status_from_string(value));
        res_ok(res, {{"success", true}});
        return res;
    };
    // List all known models in an OAI-compatible "list" envelope, augmented
    // with router-specific status information per model.
    this->get_router_models = [this](const server_http_req &) {
        auto res = std::make_unique<server_http_res>();
        json models_json = json::array();
        auto all_models = models.get_all_meta();
        // single timestamp shared by all entries in this response
        std::time_t t = std::time(0);
        for (const auto & meta : all_models) {
            json status {
                {"value", server_model_status_to_string(meta.status)},
                {"args", meta.args},
            };
            if (!meta.preset.name.empty()) {
                status["preset"] = meta.preset.to_ini();
            }
            if (meta.is_failed()) {
                status["exit_code"] = meta.exit_code;
                status["failed"] = true;
            }
            models_json.push_back(json {
                {"id", meta.name},
                {"object", "model"}, // for OAI-compat
                {"owned_by", "llamacpp"}, // for OAI-compat
                {"created", t}, // for OAI-compat
                {"in_cache", meta.in_cache},
                {"path", meta.path},
                {"status", status},
                // TODO: add other fields, may require reading GGUF metadata
            });
        }
        res_ok(res, {
            {"data", models_json},
            {"object", "list"},
        });
        return res;
    };
    // Explicitly unload a model by name (rejected if not currently loaded).
    // NOTE(review): "model is not found" maps to ERROR_TYPE_INVALID_REQUEST here
    // but to ERROR_TYPE_NOT_FOUND in post_router_models_load — looks inconsistent;
    // confirm whether clients depend on the current status codes before unifying.
    this->post_router_models_unload = [this](const server_http_req & req) {
        auto res = std::make_unique<server_http_res>();
        json body = json::parse(req.body);
        std::string name = json_value(body, "model", std::string());
        auto model = models.get_meta(name);
        if (!model.has_value()) {
            res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        if (model->status != SERVER_MODEL_STATUS_LOADED) {
            res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST));
            return res;
        }
        models.unload(name);
        res_ok(res, {{"success", true}});
        return res;
    };
}
  852. //
  853. // server_http_proxy
  854. //
  855. // simple implementation of a pipe
  856. // used for streaming data between threads
  857. template<typename T>
  858. struct pipe_t {
  859. std::mutex mutex;
  860. std::condition_variable cv;
  861. std::queue<T> queue;
  862. std::atomic<bool> writer_closed{false};
  863. std::atomic<bool> reader_closed{false};
  864. void close_write() {
  865. writer_closed.store(true, std::memory_order_relaxed);
  866. cv.notify_all();
  867. }
  868. void close_read() {
  869. reader_closed.store(true, std::memory_order_relaxed);
  870. cv.notify_all();
  871. }
  872. bool read(T & output, const std::function<bool()> & should_stop) {
  873. std::unique_lock<std::mutex> lk(mutex);
  874. constexpr auto poll_interval = std::chrono::milliseconds(500);
  875. while (true) {
  876. if (!queue.empty()) {
  877. output = std::move(queue.front());
  878. queue.pop();
  879. return true;
  880. }
  881. if (writer_closed.load()) {
  882. return false; // clean EOF
  883. }
  884. if (should_stop()) {
  885. close_read(); // signal broken pipe to writer
  886. return false; // cancelled / reader no longer alive
  887. }
  888. cv.wait_for(lk, poll_interval);
  889. }
  890. }
  891. bool write(T && data) {
  892. std::lock_guard<std::mutex> lk(mutex);
  893. if (reader_closed.load()) {
  894. return false; // broken pipe
  895. }
  896. queue.push(std::move(data));
  897. cv.notify_one();
  898. return true;
  899. }
  900. };
  901. static std::string to_lower_copy(const std::string & value) {
  902. std::string lowered(value.size(), '\0');
  903. std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); });
  904. return lowered;
  905. }
  906. static bool should_strip_proxy_header(const std::string & header_name) {
  907. // Headers that get duplicated when router forwards child responses
  908. if (header_name == "server" ||
  909. header_name == "transfer-encoding" ||
  910. header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710
  911. header_name == "keep-alive") {
  912. return true;
  913. }
  914. // Router injects CORS, child also sends them: duplicate
  915. if (header_name.rfind("access-control-", 0) == 0) {
  916. return true;
  917. }
  918. return false;
  919. }
// Forward a single HTTP request to a child server and stream the response back.
// A detached worker thread runs the httplib client; it communicates with the
// caller through a shared pipe_t: the first message carries status/headers,
// subsequent messages carry body chunks. Lifetime note: the thread captures only
// shared_ptrs (cli, pipe), never `this`, so it stays valid even if this proxy
// object is destroyed before the transfer finishes.
server_http_proxy::server_http_proxy(
        const std::string & method,
        const std::string & host,
        int port,
        const std::string & path,
        const std::map<std::string, std::string> & headers,
        const std::string & body,
        const std::function<bool()> should_stop) {
    // shared between reader and writer threads
    auto cli = std::make_shared<httplib::Client>(host, port);
    auto pipe = std::make_shared<pipe_t<msg_t>>();
    // setup Client
    cli->set_connection_timeout(0, 200000); // 200 milliseconds
    this->status = 500; // to be overwritten upon response
    // closing both ends unblocks the worker thread and the reader
    this->cleanup = [pipe]() {
        pipe->close_read();
        pipe->close_write();
    };
    // wire up the receive end of the pipe
    this->next = [pipe, should_stop](std::string & out) -> bool {
        msg_t msg;
        bool has_next = pipe->read(msg, should_stop);
        if (!msg.data.empty()) {
            out = std::move(msg.data);
        }
        return has_next; // false if EOF or pipe broken
    };
    // wire up the HTTP client
    // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
    httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
        // first pipe message: status + filtered headers (no body yet)
        msg_t msg;
        msg.status = response.status;
        for (const auto & [key, value] : response.headers) {
            const auto lowered = to_lower_copy(key);
            // drop headers the router manages itself, to avoid duplicates
            if (should_strip_proxy_header(lowered)) {
                continue;
            }
            // content-type travels in its own msg_t field, not in the header map
            if (lowered == "content-type") {
                msg.content_type = value;
                continue;
            }
            msg.headers[key] = value;
        }
        // returning false (broken pipe) tells httplib to abort the transfer
        return pipe->write(std::move(msg)); // send headers first
    };
    httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
        // send data chunks
        // returns false if pipe is closed / broken (signal to stop receiving)
        // NOTE(review): braced init relies on msg_t member order being
        // {headers, status, data, content_type} — confirm against its declaration
        return pipe->write({{}, 0, std::string(data, data_length), ""});
    };
    // prepare the request to destination server
    httplib::Request req;
    {
        req.method = method;
        req.path = path;
        for (const auto & [key, value] : headers) {
            req.set_header(key, value);
        }
        req.body = body;
        req.response_handler = response_handler;
        req.content_receiver = content_receiver;
    }
    // start the proxy thread
    SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
    this->thread = std::thread([cli, pipe, req]() {
        auto result = cli->send(std::move(req));
        if (result.error() != httplib::Error::Success) {
            // synthesize a 500 response so the reader still gets headers + body
            auto err_str = httplib::to_string(result.error());
            SRV_ERR("http client error: %s\n", err_str.c_str());
            pipe->write({{}, 500, "", ""}); // header
            pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body
        }
        pipe->close_write(); // signal EOF to reader
        SRV_DBG("%s", "client request thread ended\n");
    });
    // detached on purpose: the thread only touches the shared_ptr captures,
    // and cleanup()/close_read() is how the owner aborts it early
    this->thread.detach();
    // wait for the first chunk (headers)
    {
        msg_t header;
        if (pipe->read(header, should_stop)) {
            SRV_DBG("%s", "received response headers\n");
            this->status = header.status;
            this->headers = std::move(header.headers);
            if (!header.content_type.empty()) {
                this->content_type = std::move(header.content_type);
            }
        } else {
            // request was cancelled (or worker died) before any headers arrived;
            // this->status keeps the 500 default
            SRV_DBG("%s", "no response headers received (request cancelled?)\n");
        }
    }
}