server-models.h 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. #pragma once
  2. #include "common.h"
  3. #include "server-http.h"
  4. #include <mutex>
  5. #include <condition_variable>
  6. #include <functional>
  7. #include <memory>
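/**
 * state diagram:
 *
 * UNLOADED ──► LOADING ──► LOADED
 *  ▲            │            │
 *  └───failed───┘            │
 *  ▲                         │
 *  └────────unloaded─────────┘
 *
 * note: "failed" is not a separate status value; a failed load goes back to
 * UNLOADED with a non-zero exit code (see server_model_meta::is_failed)
 */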
  17. enum server_model_status {
  18. // TODO: also add downloading state when the logic is added
  19. SERVER_MODEL_STATUS_UNLOADED,
  20. SERVER_MODEL_STATUS_LOADING,
  21. SERVER_MODEL_STATUS_LOADED
  22. };
  23. static server_model_status server_model_status_from_string(const std::string & status_str) {
  24. if (status_str == "unloaded") {
  25. return SERVER_MODEL_STATUS_UNLOADED;
  26. }
  27. if (status_str == "loading") {
  28. return SERVER_MODEL_STATUS_LOADING;
  29. }
  30. if (status_str == "loaded") {
  31. return SERVER_MODEL_STATUS_LOADED;
  32. }
  33. throw std::runtime_error("invalid server model status");
  34. }
  35. static std::string server_model_status_to_string(server_model_status status) {
  36. switch (status) {
  37. case SERVER_MODEL_STATUS_UNLOADED: return "unloaded";
  38. case SERVER_MODEL_STATUS_LOADING: return "loading";
  39. case SERVER_MODEL_STATUS_LOADED: return "loaded";
  40. default: return "unknown";
  41. }
  42. }
  43. struct server_model_meta {
  44. std::string name;
  45. std::string path;
  46. std::string path_mmproj; // only available if in_cache=false
  47. bool in_cache = false; // if true, use -hf; use -m otherwise
  48. int port = 0;
  49. server_model_status status = SERVER_MODEL_STATUS_UNLOADED;
  50. int64_t last_used = 0; // for LRU unloading
  51. std::vector<std::string> args; // additional args passed to the model instance (used for debugging)
  52. int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED)
  53. bool is_active() const {
  54. return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING;
  55. }
  56. bool is_failed() const {
  57. return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0;
  58. }
  59. };
  60. struct subprocess_s;
  61. struct server_models {
  62. private:
  63. struct instance_t {
  64. std::shared_ptr<subprocess_s> subproc; // shared between main thread and monitoring thread
  65. std::thread th;
  66. server_model_meta meta;
  67. FILE * stdin_file = nullptr;
  68. };
  69. std::mutex mutex;
  70. std::condition_variable cv;
  71. std::map<std::string, instance_t> mapping;
  72. common_params base_params;
  73. std::vector<std::string> base_args;
  74. std::vector<std::string> base_env;
  75. void update_meta(const std::string & name, const server_model_meta & meta);
  76. // unload least recently used models if the limit is reached
  77. void unload_lru();
  78. public:
  79. server_models(const common_params & params, int argc, char ** argv, char ** envp);
  80. // check if a model instance exists
  81. bool has_model(const std::string & name);
  82. // return a copy of model metadata
  83. std::optional<server_model_meta> get_meta(const std::string & name);
  84. // return a copy of all model metadata
  85. std::vector<server_model_meta> get_all_meta();
  86. // if auto_load is true, load the model with previous args if any
  87. void load(const std::string & name, bool auto_load);
  88. void unload(const std::string & name);
  89. void unload_all();
  90. // update the status of a model instance
  91. void update_status(const std::string & name, server_model_status status);
  92. // wait until the model instance is fully loaded
  93. // return when the model is loaded or failed to load
  94. void wait_until_loaded(const std::string & name);
  95. // load the model if not loaded, otherwise do nothing
  96. // return false if model is already loaded; return true otherwise (meta may need to be refreshed)
  97. bool ensure_model_loaded(const std::string & name);
  98. // proxy an HTTP request to the model instance
  99. server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used);
  100. // notify the router server that a model instance is ready
  101. // return the monitoring thread (to be joined by the caller)
  102. static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function<void(int)> & shutdown_handler);
  103. };
  104. struct server_models_routes {
  105. common_params params;
  106. server_models models;
  107. server_models_routes(const common_params & params, int argc, char ** argv, char ** envp)
  108. : params(params), models(params, argc, argv, envp) {
  109. init_routes();
  110. }
  111. void init_routes();
  112. // handlers using lambda function, so that they can capture `this` without `std::bind`
  113. server_http_context::handler_t get_router_props;
  114. server_http_context::handler_t proxy_get;
  115. server_http_context::handler_t proxy_post;
  116. server_http_context::handler_t get_router_models;
  117. server_http_context::handler_t post_router_models_load;
  118. server_http_context::handler_t post_router_models_status;
  119. server_http_context::handler_t post_router_models_unload;
  120. };
  121. /**
  122. * A simple HTTP proxy that forwards requests to another server
  123. * and relays the responses back.
  124. */
  125. struct server_http_proxy : server_http_res {
  126. std::function<void()> cleanup = nullptr;
  127. public:
  128. server_http_proxy(const std::string & method,
  129. const std::string & host,
  130. int port,
  131. const std::string & path,
  132. const std::map<std::string, std::string> & headers,
  133. const std::string & body,
  134. const std::function<bool()> should_stop);
  135. ~server_http_proxy() {
  136. if (cleanup) {
  137. cleanup();
  138. }
  139. }
  140. private:
  141. std::thread thread;
  142. struct msg_t {
  143. std::map<std::string, std::string> headers;
  144. int status = 0;
  145. std::string data;
  146. };
  147. };