rpc-server.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  // Configuration and dependencies
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml-rpc.h"
#ifdef _WIN32
# define NOMINMAX
# define DIRECTORY_SEPARATOR '\\'
# include <locale>
# include <windows.h>
# include <fcntl.h>
# include <io.h>
#else
# define DIRECTORY_SEPARATOR '/'
# include <unistd.h>
# include <sys/stat.h>
#endif
#include <algorithm>
#include <cerrno>
#include <charconv>
#include <codecvt>
#include <cstdlib>
#include <filesystem>
#include <limits>
#include <stdio.h>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <vector>
  24. namespace fs = std::filesystem;
  // NOTE: this is copied from common.cpp to avoid linking with libcommon
// (fixed locally: Windows separator/offset handling, creation of the final
// path component, and the stat/mkdir race -- mind this when re-syncing).
//
// Creates the directory `path` together with all missing parent directories;
// a trailing separator is optional. Returns true if `path` is a directory
// afterwards, false if some component exists but is not a directory or
// cannot be created.
static bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
    const std::wstring wpath = converter.from_bytes(path);

    // if the path already exists, it must be a directory
    const DWORD attributes = GetFileAttributesW(wpath.c_str());
    if (attributes != INVALID_FILE_ATTRIBUTES) {
        return (attributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
    }

    // creates a single directory level; an already existing directory is success
    auto ensure_dir = [](const std::wstring & subpath) -> bool {
        if (CreateDirectoryW(subpath.c_str(), NULL)) {
            return true;
        }
        if (GetLastError() != ERROR_ALREADY_EXISTS) {
            return false;
        }
        const DWORD attrs = GetFileAttributesW(subpath.c_str());
        return attrs != INVALID_FILE_ATTRIBUTES && (attrs & FILE_ATTRIBUTE_DIRECTORY) != 0;
    };

    // process path from front to back, procedurally creating directories.
    // Windows accepts both '\\' and '/' as separators. Positions are searched in
    // the wide string itself: UTF-8 byte offsets into `path` do not match
    // wchar_t offsets into `wpath` once the path contains non-ASCII characters.
    size_t pos_slash = 0;
    while ((pos_slash = wpath.find_first_of(L"\\/", pos_slash)) != std::wstring::npos) {
        // skip the empty prefix of a root-relative path such as "\\dir"
        if (pos_slash > 0 && !ensure_dir(wpath.substr(0, pos_slash))) {
            return false;
        }
        pos_slash += 1;
    }
    // the last component is not followed by a separator: create it as well
    if (!wpath.empty() && wpath.back() != L'\\' && wpath.back() != L'/') {
        return ensure_dir(wpath);
    }
    return true;
#else
    // if the path already exists, it must be a directory
    struct stat info;
    if (stat(path.c_str(), &info) == 0) {
        return S_ISDIR(info.st_mode);
    }

    // creates a single directory level; an already existing directory is success
    auto ensure_dir = [](const std::string & subpath) -> bool {
        struct stat st;
        if (stat(subpath.c_str(), &st) == 0) {
            return S_ISDIR(st.st_mode);
        }
        if (mkdir(subpath.c_str(), 0755) == 0) {
            return true;
        }
        // another process may have created it between stat() and mkdir()
        return errno == EEXIST && stat(subpath.c_str(), &st) == 0 && S_ISDIR(st.st_mode);
    };

    size_t pos_slash = 1; // skip leading slashes for directory creation
    // process path from front to back, procedurally creating directories
    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
        if (!ensure_dir(path.substr(0, pos_slash))) {
            return false;
        }
        pos_slash += 1;
    }
    // the last component is not followed by a separator: create it as well
    if (!path.empty() && path.back() != '/') {
        return ensure_dir(path);
    }
    return true;
#endif // _WIN32
}
  85. // NOTE: this is copied from common.cpp to avoid linking with libcommon
  86. static std::string fs_get_cache_directory() {
  87. std::string cache_directory = "";
  88. auto ensure_trailing_slash = [](std::string p) {
  89. // Make sure to add trailing slash
  90. if (p.back() != DIRECTORY_SEPARATOR) {
  91. p += DIRECTORY_SEPARATOR;
  92. }
  93. return p;
  94. };
  95. if (getenv("LLAMA_CACHE")) {
  96. cache_directory = std::getenv("LLAMA_CACHE");
  97. } else {
  98. #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
  99. if (std::getenv("XDG_CACHE_HOME")) {
  100. cache_directory = std::getenv("XDG_CACHE_HOME");
  101. } else {
  102. cache_directory = std::getenv("HOME") + std::string("/.cache/");
  103. }
  104. #elif defined(__APPLE__)
  105. cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
  106. #elif defined(_WIN32)
  107. cache_directory = std::getenv("LOCALAPPDATA");
  108. #else
  109. # error Unknown architecture
  110. #endif
  111. cache_directory = ensure_trailing_slash(cache_directory);
  112. cache_directory += "llama.cpp";
  113. }
  114. return ensure_trailing_slash(cache_directory);
  115. }
  // Settings of the RPC server that can be changed from the command line.
struct rpc_server_params {
    std::string host{"127.0.0.1"};  // host to bind to
    int         port{50052};        // port to bind to
    size_t      backend_mem{0};     // memory to report, in bytes (0: query the device)
    bool        use_cache{false};   // enable the local file cache
    // threads for the CPU backend: half of the hardware threads, at least one
    int         n_threads{static_cast<int>(std::max(1U, std::thread::hardware_concurrency() / 2))};
    std::string device;             // device name to use (empty: let create_backend choose)
};
  124. static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
  125. fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
  126. fprintf(stderr, "options:\n");
  127. fprintf(stderr, " -h, --help show this help message and exit\n");
  128. fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
  129. fprintf(stderr, " -d DEV, --device device to use\n");
  130. fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
  131. fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
  132. fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
  133. fprintf(stderr, " -c, --cache enable local file cache\n");
  134. fprintf(stderr, "\n");
  135. }
  136. static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
  137. std::string arg;
  138. for (int i = 1; i < argc; i++) {
  139. arg = argv[i];
  140. if (arg == "-H" || arg == "--host") {
  141. if (++i >= argc) {
  142. return false;
  143. }
  144. params.host = argv[i];
  145. } else if (arg == "-t" || arg == "--threads") {
  146. if (++i >= argc) {
  147. return false;
  148. }
  149. params.n_threads = std::stoi(argv[i]);
  150. if (params.n_threads <= 0) {
  151. fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
  152. return false;
  153. }
  154. } else if (arg == "-d" || arg == "--device") {
  155. if (++i >= argc) {
  156. return false;
  157. }
  158. params.device = argv[i];
  159. if (ggml_backend_dev_by_name(params.device.c_str()) == nullptr) {
  160. fprintf(stderr, "error: unknown device: %s\n", params.device.c_str());
  161. fprintf(stderr, "available devices:\n");
  162. for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
  163. auto * dev = ggml_backend_dev_get(i);
  164. size_t free, total;
  165. ggml_backend_dev_memory(dev, &free, &total);
  166. printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
  167. }
  168. return false;
  169. }
  170. } else if (arg == "-p" || arg == "--port") {
  171. if (++i >= argc) {
  172. return false;
  173. }
  174. params.port = std::stoi(argv[i]);
  175. if (params.port <= 0 || params.port > 65535) {
  176. return false;
  177. }
  178. } else if (arg == "-c" || arg == "--cache") {
  179. params.use_cache = true;
  180. } else if (arg == "-m" || arg == "--mem") {
  181. if (++i >= argc) {
  182. return false;
  183. }
  184. params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
  185. } else if (arg == "-h" || arg == "--help") {
  186. print_usage(argc, argv, params);
  187. exit(0);
  188. } else {
  189. fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  190. print_usage(argc, argv, params);
  191. exit(0);
  192. }
  193. }
  194. return true;
  195. }
  196. static ggml_backend_t create_backend(const rpc_server_params & params) {
  197. ggml_backend_t backend = nullptr;
  198. if (!params.device.empty()) {
  199. ggml_backend_dev_t dev = ggml_backend_dev_by_name(params.device.c_str());
  200. if (dev) {
  201. backend = ggml_backend_dev_init(dev, nullptr);
  202. if (!backend) {
  203. fprintf(stderr, "Failed to create backend for device %s\n", params.device.c_str());
  204. return nullptr;
  205. }
  206. }
  207. }
  208. // try to initialize a GPU backend first
  209. if (!backend) {
  210. backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_GPU, nullptr);
  211. }
  212. // if there aren't GPU backends fallback to CPU backend
  213. if (!backend) {
  214. backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
  215. }
  216. if (backend) {
  217. fprintf(stderr, "%s: using %s backend\n", __func__, ggml_backend_name(backend));
  218. // set the number of threads
  219. ggml_backend_dev_t dev = ggml_backend_get_device(backend);
  220. ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
  221. if (reg) {
  222. auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
  223. if (ggml_backend_set_n_threads_fn) {
  224. ggml_backend_set_n_threads_fn(backend, params.n_threads);
  225. }
  226. }
  227. }
  228. return backend;
  229. }
  230. static void get_backend_memory(ggml_backend_t backend, size_t * free_mem, size_t * total_mem) {
  231. ggml_backend_dev_t dev = ggml_backend_get_device(backend);
  232. GGML_ASSERT(dev != nullptr);
  233. ggml_backend_dev_memory(dev, free_mem, total_mem);
  234. }
  235. int main(int argc, char * argv[]) {
  236. ggml_backend_load_all();
  237. rpc_server_params params;
  238. if (!rpc_server_params_parse(argc, argv, params)) {
  239. fprintf(stderr, "Invalid parameters\n");
  240. return 1;
  241. }
  242. if (params.host != "127.0.0.1") {
  243. fprintf(stderr, "\n");
  244. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  245. fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
  246. fprintf(stderr, " Never expose the RPC server to an open network!\n");
  247. fprintf(stderr, " This is an experimental feature and is not secure!\n");
  248. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  249. fprintf(stderr, "\n");
  250. }
  251. ggml_backend_t backend = create_backend(params);
  252. if (!backend) {
  253. fprintf(stderr, "Failed to create backend\n");
  254. return 1;
  255. }
  256. std::string endpoint = params.host + ":" + std::to_string(params.port);
  257. size_t free_mem, total_mem;
  258. if (params.backend_mem > 0) {
  259. free_mem = params.backend_mem;
  260. total_mem = params.backend_mem;
  261. } else {
  262. get_backend_memory(backend, &free_mem, &total_mem);
  263. }
  264. const char * cache_dir = nullptr;
  265. std::string cache_dir_str;
  266. if (params.use_cache) {
  267. cache_dir_str = fs_get_cache_directory() + "rpc/";
  268. if (!fs_create_directory_with_parents(cache_dir_str)) {
  269. fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
  270. return 1;
  271. }
  272. cache_dir = cache_dir_str.c_str();
  273. }
  274. ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
  275. if (!reg) {
  276. fprintf(stderr, "Failed to find RPC backend\n");
  277. return 1;
  278. }
  279. auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
  280. if (!start_server_fn) {
  281. fprintf(stderr, "Failed to obtain RPC backend start server function\n");
  282. return 1;
  283. }
  284. start_server_fn(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
  285. ggml_backend_free(backend);
  286. return 0;
  287. }