server-models.h 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  #pragma once

  #include "common.h"
  #include "preset.h"
  #include "server-common.h"
  #include "server-http.h"

  #include <condition_variable>
  #include <cstdint>
  #include <cstdio>
  #include <functional>
  #include <map>
  #include <memory>
  #include <mutex>
  #include <optional>
  #include <set>
  #include <stdexcept>
  #include <string>
  #include <thread>
  #include <vector>
  /**
   * Lifecycle status of a model instance managed by the router.
   *
   * state diagram:
   *
   *   UNLOADED ──► LOADING ──► LOADED
   *      ▲            │          │
   *      └───failed───┘          │
   *      ▲                       │
   *      └────────unloaded───────┘
   *
   * "failed" is not a separate state: the instance goes back to UNLOADED and
   * keeps a non-zero exit code (see server_model_meta::is_failed()).
   */
  enum server_model_status {
      // TODO: also add downloading state when the logic is added
      SERVER_MODEL_STATUS_UNLOADED,
      SERVER_MODEL_STATUS_LOADING,
      SERVER_MODEL_STATUS_LOADED
  };

  // Parse the string form of a status ("unloaded", "loading", "loaded").
  // Throws std::runtime_error (message includes the rejected value) for any other input.
  // Declared `inline` rather than `static`: a `static` function defined in a header is
  // duplicated in every translation unit and triggers -Wunused-function where unused.
  inline server_model_status server_model_status_from_string(const std::string & status_str) {
      if (status_str == "unloaded") {
          return SERVER_MODEL_STATUS_UNLOADED;
      }
      if (status_str == "loading") {
          return SERVER_MODEL_STATUS_LOADING;
      }
      if (status_str == "loaded") {
          return SERVER_MODEL_STATUS_LOADED;
      }
      throw std::runtime_error("invalid server model status: " + status_str);
  }

  // Convert a status to its string form; inverse of server_model_status_from_string().
  // Returns "unknown" for a value outside the enumerators (e.g. a cast integer).
  inline std::string server_model_status_to_string(server_model_status status) {
      // no `default:` label, so -Wswitch flags this switch when a new enumerator is added
      switch (status) {
          case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
          case SERVER_MODEL_STATUS_LOADING:  return "loading";
          case SERVER_MODEL_STATUS_LOADED:   return "loaded";
      }
      return "unknown";
  }
  46. struct server_model_meta {
  47. common_preset preset;
  48. std::string name;
  49. int port = 0;
  50. server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
  51. int64_t last_used = 0; // for LRU unloading
  52. std::vector<std::string> args; // args passed to the model instance, will be populated by render_args()
  53. int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
  54. int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown
  55. bool is_active() const {
  56. return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
  57. }
  58. bool is_failed() const {
  59. return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
  60. }
  61. void update_args(common_preset_context & ctx_presets, std::string bin_path);
  62. };
  63. struct subprocess_s;
  64. struct server_models {
  65. private:
  66. struct instance_t {
  67. std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
  68. std::thread th;
  69. server_model_meta meta;
  70. FILE * stdin_file = nullptr;
  71. };
  72. std::mutex mutex;
  73. std::condition_variable cv;
  74. std::map<std::string, instance_t> mapping;
  75. // for stopping models
  76. std::condition_variable cv_stop;
  77. std::set<std::string> stopping_models;
  78. common_preset_context ctx_preset;
  79. common_params base_params;
  80. std::string bin_path;
  81. std::vector<std::string> base_env;
  82. common_preset base_preset; // base preset from llama-server CLI args
  83. void update_meta(const std::string & name, const server_model_meta & meta);
  84. // unload least recently used models if the limit is reached
  85. void unload_lru();
  86. // not thread-safe, caller must hold mutex
  87. void add_model(server_model_meta && meta);
  88. public:
  89. server_models(const common_params & params, int argc, char ** argv);
  90. void load_models();
  91. // check if a model instance exists (thread-safe)
  92. bool has_model(const std::string & name);
  93. // return a copy of model metadata (thread-safe)
  94. std::optional<server_model_meta> get_meta(const std::string & name);
  95. // return a copy of all model metadata (thread-safe)
  96. std::vector<server_model_meta> get_all_meta();
  97. // load and unload model instances
  98. // these functions are thread-safe
  99. void load(const std::string & name);
  100. void unload(const std::string & name);
  101. void unload_all();
  102. // update the status of a model instance (thread-safe)
  103. void update_status(const std::string & name, server_model_status status, int exit_code);
  104. // wait until the model instance is fully loaded (thread-safe)
  105. // return when the model is loaded or failed to load
  106. void wait_until_loaded(const std::string & name);
  107. // load the model if not loaded, otherwise do nothing (thread-safe)
  108. // return false if model is already loaded; return true otherwise (meta may need to be refreshed)
  109. bool ensure_model_loaded(const std::string & name);
  110. // proxy an HTTP request to the model instance
  111. server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
  112. // notify the router server that a model instance is ready
  113. // return the monitoring thread (to be joined by the caller)
  114. static std::thread setup_child_server(const std::function<void(int)> & shutdown_handler);
  115. };
  116. struct server_models_routes {
  117. common_params params;
  118. json webui_settings = json::object();
  119. server_models models;
  120. server_models_routes(const common_params & params, int argc, char ** argv)
  121. : params(params), models(params, argc, argv) {
  122. if (!this->params.webui_config_json.empty()) {
  123. try {
  124. webui_settings = json::parse(this->params.webui_config_json);
  125. } catch (const std::exception & e) {
  126. LOG_ERR("%s: failed to parse webui config: %s\n", __func__, e.what());
  127. throw;
  128. }
  129. }
  130. init_routes();
  131. }
  132. void init_routes();
  133. // handlers using lambda function, so that they can capture `this` without `std::bind`
  134. server_http_context::handler_t get_router_props;
  135. server_http_context::handler_t proxy_get;
  136. server_http_context::handler_t proxy_post;
  137. server_http_context::handler_t get_router_models;
  138. server_http_context::handler_t post_router_models_load;
  139. server_http_context::handler_t post_router_models_unload;
  140. };
  141. /**
  142. * A simple HTTP proxy that forwards requests to another server
  143. * and relays the responses back.
  144. */
  145. struct server_http_proxy : server_http_res {
  146. std::function<void()> cleanup = nullptr;
  147. public:
  148. server_http_proxy(const std::string & method,
  149. const std::string & host,
  150. int port,
  151. const std::string & path,
  152. const std::map<std::string, std::string> & headers,
  153. const std::string & body,
  154. const std::function<bool()> should_stop,
  155. int32_t timeout_read,
  156. int32_t timeout_write
  157. );
  158. ~server_http_proxy() {
  159. if (cleanup) {
  160. cleanup();
  161. }
  162. }
  163. private:
  164. std::thread thread;
  165. struct msg_t {
  166. std::map<std::string, std::string> headers;
  167. int status = 0;
  168. std::string data;
  169. std::string content_type;
  170. };
  171. };