rpc-server.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. #if defined(_MSC_VER)
  2. #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
  3. #endif
  4. #include "ggml-rpc.h"
  5. #ifdef _WIN32
  6. # define NOMINMAX
  7. # define DIRECTORY_SEPARATOR '\\'
  8. # include <locale>
  9. # include <windows.h>
  10. # include <fcntl.h>
  11. # include <io.h>
  12. #else
  13. # define DIRECTORY_SEPARATOR '/'
  14. # include <unistd.h>
  15. # include <sys/stat.h>
  16. #endif
#include <stdio.h>

#include <algorithm>
#include <codecvt>
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <limits>
#include <regex>
#include <string>
#include <system_error>
#include <thread>
#include <vector>
  25. namespace fs = std::filesystem;
  26. // NOTE: this is copied from common.cpp to avoid linking with libcommon
  27. // returns true if successful, false otherwise
  28. static bool fs_create_directory_with_parents(const std::string & path) {
  29. #ifdef _WIN32
  30. std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
  31. std::wstring wpath = converter.from_bytes(path);
  32. // if the path already exists, check whether it's a directory
  33. const DWORD attributes = GetFileAttributesW(wpath.c_str());
  34. if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
  35. return true;
  36. }
  37. size_t pos_slash = 0;
  38. // process path from front to back, procedurally creating directories
  39. while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
  40. const std::wstring subpath = wpath.substr(0, pos_slash);
  41. const wchar_t * test = subpath.c_str();
  42. const bool success = CreateDirectoryW(test, NULL);
  43. if (!success) {
  44. const DWORD error = GetLastError();
  45. // if the path already exists, ensure that it's a directory
  46. if (error == ERROR_ALREADY_EXISTS) {
  47. const DWORD attributes = GetFileAttributesW(subpath.c_str());
  48. if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
  49. return false;
  50. }
  51. } else {
  52. return false;
  53. }
  54. }
  55. pos_slash += 1;
  56. }
  57. return true;
  58. #else
  59. // if the path already exists, check whether it's a directory
  60. struct stat info;
  61. if (stat(path.c_str(), &info) == 0) {
  62. return S_ISDIR(info.st_mode);
  63. }
  64. size_t pos_slash = 1; // skip leading slashes for directory creation
  65. // process path from front to back, procedurally creating directories
  66. while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
  67. const std::string subpath = path.substr(0, pos_slash);
  68. struct stat info;
  69. // if the path already exists, ensure that it's a directory
  70. if (stat(subpath.c_str(), &info) == 0) {
  71. if (!S_ISDIR(info.st_mode)) {
  72. return false;
  73. }
  74. } else {
  75. // create parent directories
  76. const int ret = mkdir(subpath.c_str(), 0755);
  77. if (ret != 0) {
  78. return false;
  79. }
  80. }
  81. pos_slash += 1;
  82. }
  83. return true;
  84. #endif // _WIN32
  85. }
  86. // NOTE: this is copied from common.cpp to avoid linking with libcommon
  87. static std::string fs_get_cache_directory() {
  88. std::string cache_directory = "";
  89. auto ensure_trailing_slash = [](std::string p) {
  90. // Make sure to add trailing slash
  91. if (p.back() != DIRECTORY_SEPARATOR) {
  92. p += DIRECTORY_SEPARATOR;
  93. }
  94. return p;
  95. };
  96. if (getenv("LLAMA_CACHE")) {
  97. cache_directory = std::getenv("LLAMA_CACHE");
  98. } else {
  99. #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
  100. if (std::getenv("XDG_CACHE_HOME")) {
  101. cache_directory = std::getenv("XDG_CACHE_HOME");
  102. } else {
  103. cache_directory = std::getenv("HOME") + std::string("/.cache/");
  104. }
  105. #elif defined(__APPLE__)
  106. cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
  107. #elif defined(_WIN32)
  108. cache_directory = std::getenv("LOCALAPPDATA");
  109. #else
  110. # error Unknown architecture
  111. #endif
  112. cache_directory = ensure_trailing_slash(cache_directory);
  113. cache_directory += "llama.cpp";
  114. }
  115. return ensure_trailing_slash(cache_directory);
  116. }
  117. struct rpc_server_params {
  118. std::string host = "127.0.0.1";
  119. int port = 50052;
  120. bool use_cache = false;
  121. int n_threads = std::max(1U, std::thread::hardware_concurrency()/2);
  122. std::vector<std::string> devices;
  123. std::vector<size_t> dev_mem;
  124. };
  125. static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
  126. fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
  127. fprintf(stderr, "options:\n");
  128. fprintf(stderr, " -h, --help show this help message and exit\n");
  129. fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
  130. fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
  131. fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
  132. fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
  133. fprintf(stderr, " -m, --mem <M1,M2,...> memory size for each device (in MB)\n");
  134. fprintf(stderr, " -c, --cache enable local file cache\n");
  135. fprintf(stderr, "\n");
  136. }
  137. static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
  138. std::string arg;
  139. for (int i = 1; i < argc; i++) {
  140. arg = argv[i];
  141. if (arg == "-H" || arg == "--host") {
  142. if (++i >= argc) {
  143. return false;
  144. }
  145. params.host = argv[i];
  146. } else if (arg == "-t" || arg == "--threads") {
  147. if (++i >= argc) {
  148. return false;
  149. }
  150. params.n_threads = std::stoi(argv[i]);
  151. if (params.n_threads <= 0) {
  152. fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
  153. return false;
  154. }
  155. } else if (arg == "-d" || arg == "--device") {
  156. if (++i >= argc) {
  157. return false;
  158. }
  159. const std::regex regex{ R"([,/]+)" };
  160. std::string dev_str = argv[i];
  161. std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
  162. std::sregex_token_iterator end;
  163. for ( ; iter != end; ++iter) {
  164. try {
  165. params.devices.push_back(*iter);
  166. } catch (const std::exception & ) {
  167. fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
  168. return false;
  169. }
  170. }
  171. } else if (arg == "-p" || arg == "--port") {
  172. if (++i >= argc) {
  173. return false;
  174. }
  175. params.port = std::stoi(argv[i]);
  176. if (params.port <= 0 || params.port > 65535) {
  177. return false;
  178. }
  179. } else if (arg == "-c" || arg == "--cache") {
  180. params.use_cache = true;
  181. } else if (arg == "-m" || arg == "--mem") {
  182. if (++i >= argc) {
  183. return false;
  184. }
  185. const std::regex regex{ R"([,/]+)" };
  186. std::string mem_str = argv[i];
  187. std::sregex_token_iterator iter(mem_str.begin(), mem_str.end(), regex, -1);
  188. std::sregex_token_iterator end;
  189. for ( ; iter != end; ++iter) {
  190. try {
  191. size_t mem = std::stoul(*iter) * 1024 * 1024;
  192. params.dev_mem.push_back(mem);
  193. } catch (const std::exception & ) {
  194. fprintf(stderr, "error: invalid memory size: %s\n", iter->str().c_str());
  195. return false;
  196. }
  197. }
  198. } else if (arg == "-h" || arg == "--help") {
  199. print_usage(argc, argv, params);
  200. exit(0);
  201. } else {
  202. fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  203. print_usage(argc, argv, params);
  204. exit(0);
  205. }
  206. }
  207. return true;
  208. }
  209. static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
  210. std::vector<ggml_backend_dev_t> devices;
  211. if (!params.devices.empty()) {
  212. for (auto device : params.devices) {
  213. ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
  214. if (dev) {
  215. devices.push_back(dev);
  216. } else {
  217. fprintf(stderr, "error: unknown device: %s\n", device.c_str());
  218. fprintf(stderr, "available devices:\n");
  219. for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
  220. auto * dev = ggml_backend_dev_get(i);
  221. size_t free, total;
  222. ggml_backend_dev_memory(dev, &free, &total);
  223. printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
  224. }
  225. return {};
  226. }
  227. }
  228. }
  229. // Try non-CPU devices first
  230. if (devices.empty()) {
  231. for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
  232. ggml_backend_dev_t dev = ggml_backend_dev_get(i);
  233. if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
  234. devices.push_back(dev);
  235. }
  236. }
  237. }
  238. // If there are no accelerators, fallback to CPU device
  239. if (devices.empty()) {
  240. ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
  241. if (dev) {
  242. devices.push_back(dev);
  243. }
  244. }
  245. return devices;
  246. }
  247. int main(int argc, char * argv[]) {
  248. ggml_backend_load_all();
  249. rpc_server_params params;
  250. if (!rpc_server_params_parse(argc, argv, params)) {
  251. fprintf(stderr, "Invalid parameters\n");
  252. return 1;
  253. }
  254. if (params.host != "127.0.0.1") {
  255. fprintf(stderr, "\n");
  256. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  257. fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
  258. fprintf(stderr, " Never expose the RPC server to an open network!\n");
  259. fprintf(stderr, " This is an experimental feature and is not secure!\n");
  260. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  261. fprintf(stderr, "\n");
  262. }
  263. auto devices = get_devices(params);
  264. if (devices.empty()) {
  265. fprintf(stderr, "No devices found\n");
  266. return 1;
  267. }
  268. std::string endpoint = params.host + ":" + std::to_string(params.port);
  269. std::vector<size_t> free_mem, total_mem;
  270. for (size_t i = 0; i < devices.size(); i++) {
  271. if (i < params.dev_mem.size()) {
  272. free_mem.push_back(params.dev_mem[i]);
  273. total_mem.push_back(params.dev_mem[i]);
  274. } else {
  275. size_t free, total;
  276. ggml_backend_dev_memory(devices[i], &free, &total);
  277. free_mem.push_back(free);
  278. total_mem.push_back(total);
  279. }
  280. }
  281. const char * cache_dir = nullptr;
  282. std::string cache_dir_str;
  283. if (params.use_cache) {
  284. cache_dir_str = fs_get_cache_directory() + "rpc/";
  285. if (!fs_create_directory_with_parents(cache_dir_str)) {
  286. fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
  287. return 1;
  288. }
  289. cache_dir = cache_dir_str.c_str();
  290. }
  291. ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
  292. if (!reg) {
  293. fprintf(stderr, "Failed to find RPC backend\n");
  294. return 1;
  295. }
  296. auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
  297. if (!start_server_fn) {
  298. fprintf(stderr, "Failed to obtain RPC backend start server function\n");
  299. return 1;
  300. }
  301. start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(),
  302. devices.data(), free_mem.data(), total_mem.data());
  303. return 0;
  304. }