rpc-server.cpp 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. #if defined(_MSC_VER)
  2. #define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
  3. #endif
  4. #include "ggml-rpc.h"
  5. #ifdef _WIN32
  6. # define NOMINMAX
  7. # define DIRECTORY_SEPARATOR '\\'
  8. # include <locale>
  9. # include <windows.h>
  10. # include <fcntl.h>
  11. # include <io.h>
  12. #else
  13. # define DIRECTORY_SEPARATOR '/'
  14. # include <unistd.h>
  15. # include <sys/stat.h>
  16. #endif
  17. #include <codecvt>
  18. #include <string>
  19. #include <stdio.h>
  20. #include <vector>
  21. #include <filesystem>
  22. #include <algorithm>
  23. #include <thread>
  24. #include <regex>
  25. namespace fs = std::filesystem;
  // NOTE: this is copied from common.cpp to avoid linking with libcommon
  //
  // Creates the directory `path` together with every missing parent directory.
  //
  // The path may or may not end with a directory separator; in both cases the
  // last component is created as a directory. On Windows both '\\' and '/' are
  // accepted as separators, on other platforms only '/'.
  //
  // Returns true if `path` exists as a directory when the function returns
  // (an empty path creates nothing and also yields true). Returns false if a
  // component exists but is not a directory, or if a directory could not be
  // created.
  static bool fs_create_directory_with_parents(const std::string & path) {
  #ifdef _WIN32
      std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
      const std::wstring wpath = converter.from_bytes(path);

      // if the path already exists, check whether it's a directory
      const DWORD path_attributes = GetFileAttributesW(wpath.c_str());
      if ((path_attributes != INVALID_FILE_ATTRIBUTES) && (path_attributes & FILE_ATTRIBUTE_DIRECTORY)) {
          return true;
      }

      // creates `dir`, treating "already exists as a directory" as success
      const auto ensure_directory = [](const std::wstring & dir) -> bool {
          if (CreateDirectoryW(dir.c_str(), NULL)) {
              return true;
          }
          if (GetLastError() != ERROR_ALREADY_EXISTS) {
              return false;
          }
          const DWORD attributes = GetFileAttributesW(dir.c_str());
          return attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
      };

      // process path from front to back, procedurally creating directories;
      // the separators are searched in the wide string so that the offsets stay
      // valid for paths containing non-ASCII characters
      size_t pos_sep = 0;
      while ((pos_sep = wpath.find_first_of(L"\\/", pos_sep)) != std::wstring::npos) {
          if (!ensure_directory(wpath.substr(0, pos_sep))) {
              return false;
          }
          pos_sep += 1;
      }

      // the loop only handles components followed by a separator
      if (!wpath.empty() && wpath.back() != L'\\' && wpath.back() != L'/') {
          return ensure_directory(wpath);
      }
      return true;
  #else
      // if the path already exists, check whether it's a directory
      struct stat info;
      if (stat(path.c_str(), &info) == 0) {
          return S_ISDIR(info.st_mode);
      }

      // creates `dir` unless it already exists; fails if it exists as a non-directory
      const auto ensure_directory = [](const std::string & dir) -> bool {
          struct stat st;
          if (stat(dir.c_str(), &st) == 0) {
              return S_ISDIR(st.st_mode);
          }
          if (mkdir(dir.c_str(), 0755) == 0) {
              return true;
          }
          // somebody else may have created it between stat() and mkdir()
          return stat(dir.c_str(), &st) == 0 && S_ISDIR(st.st_mode);
      };

      // process path from front to back, procedurally creating directories
      size_t pos_slash = 1; // skip leading slashes for directory creation
      while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
          if (!ensure_directory(path.substr(0, pos_slash))) {
              return false;
          }
          pos_slash += 1;
      }

      // the loop only handles components followed by a slash
      if (!path.empty() && path.back() != '/') {
          return ensure_directory(path);
      }
      return true;
  #endif // _WIN32
  }
  86. // NOTE: this is copied from common.cpp to avoid linking with libcommon
  87. static std::string fs_get_cache_directory() {
  88. std::string cache_directory = "";
  89. auto ensure_trailing_slash = [](std::string p) {
  90. // Make sure to add trailing slash
  91. if (p.back() != DIRECTORY_SEPARATOR) {
  92. p += DIRECTORY_SEPARATOR;
  93. }
  94. return p;
  95. };
  96. if (getenv("LLAMA_CACHE")) {
  97. cache_directory = std::getenv("LLAMA_CACHE");
  98. } else {
  99. #if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__)
  100. if (std::getenv("XDG_CACHE_HOME")) {
  101. cache_directory = std::getenv("XDG_CACHE_HOME");
  102. } else {
  103. cache_directory = std::getenv("HOME") + std::string("/.cache/");
  104. }
  105. #elif defined(__APPLE__)
  106. cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
  107. #elif defined(_WIN32)
  108. cache_directory = std::getenv("LOCALAPPDATA");
  109. #else
  110. # error Unknown architecture
  111. #endif
  112. cache_directory = ensure_trailing_slash(cache_directory);
  113. cache_directory += "llama.cpp";
  114. }
  115. return ensure_trailing_slash(cache_directory);
  116. }
  // Settings of the RPC server that can be changed from the command line.
  struct rpc_server_params {
      // host to bind to
      std::string host{"127.0.0.1"};

      // port to bind to
      int port{50052};

      // enables the local file cache
      bool use_cache{false};

      // number of threads for the CPU device; defaults to half of the reported
      // hardware threads, but never less than one
      int n_threads{static_cast<int>(std::max(1U, std::thread::hardware_concurrency() / 2))};

      // names of the devices requested by the user; when empty, the devices are
      // picked automatically
      std::vector<std::string> devices{};
  };
  124. static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
  125. fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
  126. fprintf(stderr, "options:\n");
  127. fprintf(stderr, " -h, --help show this help message and exit\n");
  128. fprintf(stderr, " -t, --threads N number of threads for the CPU device (default: %d)\n", params.n_threads);
  129. fprintf(stderr, " -d, --device <dev1,dev2,...> comma-separated list of devices\n");
  130. fprintf(stderr, " -H, --host HOST host to bind to (default: %s)\n", params.host.c_str());
  131. fprintf(stderr, " -p, --port PORT port to bind to (default: %d)\n", params.port);
  132. fprintf(stderr, " -c, --cache enable local file cache\n");
  133. fprintf(stderr, "\n");
  134. }
  135. static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
  136. std::string arg;
  137. for (int i = 1; i < argc; i++) {
  138. arg = argv[i];
  139. if (arg == "-H" || arg == "--host") {
  140. if (++i >= argc) {
  141. return false;
  142. }
  143. params.host = argv[i];
  144. } else if (arg == "-t" || arg == "--threads") {
  145. if (++i >= argc) {
  146. return false;
  147. }
  148. params.n_threads = std::stoi(argv[i]);
  149. if (params.n_threads <= 0) {
  150. fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
  151. return false;
  152. }
  153. } else if (arg == "-d" || arg == "--device") {
  154. if (++i >= argc) {
  155. return false;
  156. }
  157. const std::regex regex{ R"([,/]+)" };
  158. std::string dev_str = argv[i];
  159. std::sregex_token_iterator iter(dev_str.begin(), dev_str.end(), regex, -1);
  160. std::sregex_token_iterator end;
  161. for ( ; iter != end; ++iter) {
  162. try {
  163. params.devices.push_back(*iter);
  164. } catch (const std::exception & ) {
  165. fprintf(stderr, "error: invalid device: %s\n", iter->str().c_str());
  166. return false;
  167. }
  168. }
  169. } else if (arg == "-p" || arg == "--port") {
  170. if (++i >= argc) {
  171. return false;
  172. }
  173. params.port = std::stoi(argv[i]);
  174. if (params.port <= 0 || params.port > 65535) {
  175. return false;
  176. }
  177. } else if (arg == "-c" || arg == "--cache") {
  178. params.use_cache = true;
  179. } else if (arg == "-h" || arg == "--help") {
  180. print_usage(argc, argv, params);
  181. exit(0);
  182. } else {
  183. fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
  184. print_usage(argc, argv, params);
  185. exit(0);
  186. }
  187. }
  188. return true;
  189. }
  190. static std::vector<ggml_backend_dev_t> get_devices(const rpc_server_params & params) {
  191. std::vector<ggml_backend_dev_t> devices;
  192. if (!params.devices.empty()) {
  193. for (auto device : params.devices) {
  194. ggml_backend_dev_t dev = ggml_backend_dev_by_name(device.c_str());
  195. if (dev) {
  196. devices.push_back(dev);
  197. } else {
  198. fprintf(stderr, "error: unknown device: %s\n", device.c_str());
  199. fprintf(stderr, "available devices:\n");
  200. for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
  201. auto * dev = ggml_backend_dev_get(i);
  202. size_t free, total;
  203. ggml_backend_dev_memory(dev, &free, &total);
  204. printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
  205. }
  206. return {};
  207. }
  208. }
  209. }
  210. // Try non-CPU devices first
  211. if (devices.empty()) {
  212. for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
  213. ggml_backend_dev_t dev = ggml_backend_dev_get(i);
  214. if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU) {
  215. devices.push_back(dev);
  216. }
  217. }
  218. }
  219. // If there are no accelerators, fallback to CPU device
  220. if (devices.empty()) {
  221. ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
  222. if (dev) {
  223. devices.push_back(dev);
  224. }
  225. }
  226. return devices;
  227. }
  228. int main(int argc, char * argv[]) {
  229. ggml_backend_load_all();
  230. rpc_server_params params;
  231. if (!rpc_server_params_parse(argc, argv, params)) {
  232. fprintf(stderr, "Invalid parameters\n");
  233. return 1;
  234. }
  235. if (params.host != "127.0.0.1") {
  236. fprintf(stderr, "\n");
  237. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  238. fprintf(stderr, "WARNING: Host ('%s') is != '127.0.0.1'\n", params.host.c_str());
  239. fprintf(stderr, " Never expose the RPC server to an open network!\n");
  240. fprintf(stderr, " This is an experimental feature and is not secure!\n");
  241. fprintf(stderr, "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n");
  242. fprintf(stderr, "\n");
  243. }
  244. auto devices = get_devices(params);
  245. if (devices.empty()) {
  246. fprintf(stderr, "No devices found\n");
  247. return 1;
  248. }
  249. std::string endpoint = params.host + ":" + std::to_string(params.port);
  250. const char * cache_dir = nullptr;
  251. std::string cache_dir_str;
  252. if (params.use_cache) {
  253. cache_dir_str = fs_get_cache_directory() + "rpc/";
  254. if (!fs_create_directory_with_parents(cache_dir_str)) {
  255. fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
  256. return 1;
  257. }
  258. cache_dir = cache_dir_str.c_str();
  259. }
  260. ggml_backend_reg_t reg = ggml_backend_reg_by_name("RPC");
  261. if (!reg) {
  262. fprintf(stderr, "Failed to find RPC backend\n");
  263. return 1;
  264. }
  265. auto start_server_fn = (decltype(ggml_backend_rpc_start_server)*) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_start_server");
  266. if (!start_server_fn) {
  267. fprintf(stderr, "Failed to obtain RPC backend start server function\n");
  268. return 1;
  269. }
  270. start_server_fn(endpoint.c_str(), cache_dir, params.n_threads, devices.size(), devices.data());
  271. return 0;
  272. }