// ggml-rpc.cpp
#include "ggml-rpc.h"
#include "ggml.h"
#include "ggml-backend-impl.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
#  define WIN32_LEAN_AND_MEAN
#  ifndef NOMINMAX
#     define NOMINMAX
#  endif
#  include <windows.h>
#  include <winsock2.h>
#else
#  include <arpa/inet.h>
#  include <sys/socket.h>
#  include <sys/types.h>
#  include <netinet/in.h>
#  include <netinet/tcp.h>
#  include <netdb.h>
#  include <unistd.h>
#endif
#include <string.h>

#define UNUSED GGML_UNUSED

#define GGML_DEBUG 0
#if (GGML_DEBUG >= 1)
#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
#else
#define GGML_PRINT_DEBUG(...)
#endif

#ifdef _WIN32
typedef SOCKET sockfd_t;
using ssize_t = __int64;
#else
typedef int sockfd_t;
#endif

// cross-platform socket
struct socket_t {
    sockfd_t fd;
    socket_t(sockfd_t fd) : fd(fd) {}
    ~socket_t() {
#ifdef _WIN32
        closesocket(this->fd);
#else
        close(this->fd);
#endif
    }
};
// ggml_tensor is serialized into rpc_tensor
struct rpc_tensor {
    uint64_t id;
    uint32_t type;
    uint64_t buffer;
    uint32_t ne[GGML_MAX_DIMS];
    uint32_t nb[GGML_MAX_DIMS];
    uint32_t op;
    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
    int32_t  flags;
    uint64_t src[GGML_MAX_SRC];
    uint64_t view_src;
    uint64_t view_offs;
    uint64_t data;
    char     name[GGML_MAX_NAME];
};
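
// NOTE: rpc_tensor is copied to and from the wire with plain memcpy, so client and
// server must agree on struct layout, alignment and endianness. A possible
// compile-time guard (a sketch, assuming <type_traits> is included):
//
//   static_assert(std::is_trivially_copyable<rpc_tensor>::value,
//                 "rpc_tensor must be trivially copyable to be sent with memcpy");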
// RPC commands
enum rpc_cmd {
    ALLOC_BUFFER = 0,
    GET_ALIGNMENT,
    GET_MAX_SIZE,
    BUFFER_GET_BASE,
    FREE_BUFFER,
    BUFFER_CLEAR,
    SET_TENSOR,
    GET_TENSOR,
    COPY_TENSOR,
    GRAPH_COMPUTE,
    GET_DEVICE_MEMORY,
};
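
// NOTE: the enumerator values double as the one-byte command field on the wire,
// so the order above is part of the protocol and must match on both peers.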
// RPC data structures
static ggml_guid_t ggml_backend_rpc_guid() {
    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
    return &guid;
}

struct ggml_backend_rpc_buffer_type_context {
    std::shared_ptr<socket_t> sock;
    std::string name;
    size_t alignment;
    size_t max_size;
};

struct ggml_backend_rpc_context {
    std::string endpoint;
    std::string name;
    std::shared_ptr<socket_t> sock;
    ggml_backend_buffer_type_t buft;
};

struct ggml_backend_rpc_buffer_context {
    std::shared_ptr<socket_t> sock;
    std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
    uint64_t remote_ptr;
    std::string name;
};

// RPC helper functions
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    if (fd == INVALID_SOCKET) {
        return nullptr;
    }
#else
    if (fd < 0) {
        return nullptr;
    }
#endif
    return std::make_shared<socket_t>(fd);
}
static bool set_no_delay(sockfd_t sockfd) {
    int flag = 1;
    // set TCP_NODELAY to disable Nagle's algorithm
    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
    return ret == 0;
}

static bool set_reuse_addr(sockfd_t sockfd) {
    int flag = 1;
    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
    return ret == 0;
}

static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
    struct sockaddr_in addr;
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock_ptr = make_socket(sockfd);
    if (sock_ptr == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(sockfd)) {
        fprintf(stderr, "Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    struct hostent * server = gethostbyname(host);
    if (server == NULL) {
        fprintf(stderr, "Cannot resolve host '%s'\n", host);
        return nullptr;
    }
    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        return nullptr;
    }
    return sock_ptr;
}

static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
    auto client_socket = make_socket(client_socket_fd);
    if (client_socket == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(client_socket_fd)) {
        fprintf(stderr, "Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    return client_socket;
}

static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock = make_socket(sockfd);
    if (sock == nullptr) {
        return nullptr;
    }
    if (!set_reuse_addr(sockfd)) {
        fprintf(stderr, "Failed to set SO_REUSEADDR\n");
        return nullptr;
    }
    struct sockaddr_in serv_addr;
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = inet_addr(host);
    serv_addr.sin_port = htons(port);
    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
        return nullptr;
    }
    if (listen(sockfd, 1) < 0) {
        return nullptr;
    }
    return sock;
}

static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
    size_t bytes_sent = 0;
    while (bytes_sent < size) {
        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
        if (n < 0) {
            return false;
        }
        bytes_sent += n;
    }
    return true;
}

static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
    size_t bytes_recv = 0;
    while (bytes_recv < size) {
        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
        if (n <= 0) {
            return false;
        }
        bytes_recv += n;
    }
    return true;
}

static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
    std::string str(endpoint);
    size_t pos = str.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    host = str.substr(0, pos);
    port = std::stoi(str.substr(pos + 1));
    return true;
}
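
// Example: parse_endpoint("192.168.1.10:50052", host, port) yields
// host = "192.168.1.10" and port = 50052. Note that std::stoi throws if the port
// part is not numeric, so callers passing untrusted endpoint strings may want to
// catch that.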
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    uint8_t cmd_byte = cmd;
    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
        return false;
    }
    uint64_t input_size = input.size();
    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
        return false;
    }
    if (!send_data(sock->fd, input.data(), input.size())) {
        return false;
    }
    uint64_t output_size;
    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
        return false;
    }
    if (output_size == 0) {
        output.clear();
        return true;
    }
    output.resize(output_size);
    if (!recv_data(sock->fd, output.data(), output_size)) {
        return false;
    }
    return true;
}
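
// NOTE: the size fields are transmitted in host byte order, so the protocol assumes
// both peers share the same endianness. As a framing example, a GET_ALIGNMENT
// request (command value 1, empty payload) is 9 bytes on a little-endian machine:
//
//   | 0x01 | 0x00 0x00 0x00 0x00 0x00 0x00 0x00 0x00 |
//     cmd    request_size = 0
//
// and the response is an 8-byte response_size (= 8) followed by the alignment value.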
// RPC client-side implementation

GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    return ctx->name.c_str();
}

GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | remote_ptr (8 bytes) |
    std::vector<uint8_t> input(sizeof(uint64_t), 0);
    uint64_t remote_ptr = ctx->remote_ptr;
    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.empty());
    delete ctx;
}

GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
        return ctx->base_cache[buffer];
    }
    // input serialization format: | remote_ptr (8 bytes) |
    std::vector<uint8_t> input(sizeof(uint64_t), 0);
    uint64_t remote_ptr = ctx->remote_ptr;
    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | base_ptr (8 bytes) |
    uint64_t base_ptr;
    memcpy(&base_ptr, output.data(), sizeof(base_ptr));
    void * base = reinterpret_cast<void *>(base_ptr);
    ctx->base_cache[buffer] = base;
    return base;
}

static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
    rpc_tensor result;
    result.id = reinterpret_cast<uint64_t>(tensor);
    result.type = tensor->type;
    if (tensor->buffer) {
        ggml_backend_buffer_t buffer = tensor->buffer;
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
        result.buffer = ctx->remote_ptr;
    } else {
        result.buffer = 0;
    }
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result.ne[i] = tensor->ne[i];
        result.nb[i] = tensor->nb[i];
    }
    result.op = tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result.op_params[i] = tensor->op_params[i];
    }
    result.flags = tensor->flags;
    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
    }
    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
    result.view_offs = tensor->view_offs;
    result.data = reinterpret_cast<uint64_t>(tensor->data);
    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
    return result;
}

static ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result->nb[i] = tensor->nb[i];
    }
    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
    result->op = (ggml_op) tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result->op_params[i] = tensor->op_params[i];
    }
    result->flags = tensor->flags;
    result->data = reinterpret_cast<void *>(tensor->data);
    ggml_set_name(result, tensor->name);
    return result;
}
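
// NOTE: the pointer exchange works as follows: serialize_tensor stores the
// client-side ggml_tensor address in `id` (used purely as a lookup key in
// rpc_graph_compute) and the server-side buffer handle in `buffer` (the remote_ptr
// returned by ALLOC_BUFFER), while deserialize_tensor turns `buffer` and `data`
// back into pointers that are only dereferenced in the server's own address space.
// Also note that ne/nb are narrowed from int64_t to uint32_t on the wire, which
// limits the representable tensor shapes and strides to 32 bits.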
GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    UNUSED(buffer);
    if (ggml_is_quantized(tensor->type)) {
        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
    }
}

GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor serialized = serialize_tensor(tensor);
    memcpy(input.data(), &serialized, sizeof(serialized));
    memcpy(input.data() + sizeof(serialized), &offset, sizeof(offset));
    memcpy(input.data() + sizeof(serialized) + sizeof(offset), data, size);
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
    GGML_ASSERT(status);
}

GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
    size_t input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor serialized = serialize_tensor(tensor);
    memcpy(input.data(), &serialized, sizeof(serialized));
    memcpy(input.data() + sizeof(serialized), &offset, sizeof(offset));
    memcpy(input.data() + sizeof(serialized) + sizeof(offset), &size, sizeof(size));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == size);
    // output serialization format: | data (size bytes) |
    memcpy(data, output.data(), size);
}

GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    // check if src and dst are on the same server
    ggml_backend_buffer_t src_buffer = src->buffer;
    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
    ggml_backend_buffer_t dst_buffer = dst->buffer;
    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
    if (src_ctx->sock != dst_ctx->sock) {
        return false;
    }
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor src | rpc_tensor dst |
    size_t input_size = 2*sizeof(rpc_tensor);
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor rpc_src = serialize_tensor(src);
    rpc_tensor rpc_dst = serialize_tensor(dst);
    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
    GGML_ASSERT(status);
    // output serialization format: | result (1 byte) |
    GGML_ASSERT(output.size() == 1);
    return output[0];
}

GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
    size_t input_size = sizeof(uint64_t) + sizeof(uint8_t);
    std::vector<uint8_t> input(input_size, 0);
    memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
    memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
    GGML_ASSERT(status);
}

static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
    /* .get_name    = */ ggml_backend_rpc_buffer_get_name,
    /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
    /* .get_base    = */ ggml_backend_rpc_buffer_get_base,
    /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
    /* .set_tensor  = */ ggml_backend_rpc_buffer_set_tensor,
    /* .get_tensor  = */ ggml_backend_rpc_buffer_get_tensor,
    /* .cpy_tensor  = */ ggml_backend_rpc_buffer_cpy_tensor,
    /* .clear       = */ ggml_backend_rpc_buffer_clear,
    /* .reset       = */ NULL,
};
GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->name.c_str();
}

GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    // input serialization format: | size (8 bytes) |
    size_t input_size = sizeof(uint64_t);
    std::vector<uint8_t> input(input_size, 0);
    memcpy(input.data(), &size, sizeof(size));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
    uint64_t remote_ptr;
    memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
    size_t remote_size;
    memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
    ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
        ggml_backend_rpc_buffer_interface,
        new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
        remote_size);
    return buffer;
}

static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | alignment (8 bytes) |
    uint64_t alignment;
    memcpy(&alignment, output.data(), sizeof(alignment));
    return alignment;
}

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->alignment;
}

static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | max_size (8 bytes) |
    uint64_t max_size;
    memcpy(&max_size, output.data(), sizeof(max_size));
    return max_size;
}

GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->max_size;
}

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    UNUSED(buft);
    return ggml_nbytes(tensor);
}

GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
    if (!ggml_backend_is_rpc(backend)) {
        return false;
    }
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    return buft_ctx->sock == rpc_ctx->sock;
}

static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    /* .get_name         = */ ggml_backend_rpc_buffer_type_name,
    /* .alloc_buffer     = */ ggml_backend_rpc_buffer_type_alloc_buffer,
    /* .get_alignment    = */ ggml_backend_rpc_buffer_type_get_alignment,
    /* .get_max_size     = */ ggml_backend_rpc_get_max_size,
    /* .get_alloc_size   = */ ggml_backend_rpc_buffer_type_get_alloc_size,
    /* .supports_backend = */ ggml_backend_rpc_buffer_type_supports_backend,
    /* .is_host          = */ NULL,
};
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    return rpc_ctx->name.c_str();
}

GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
    // NOTE: the entry cached in `instances` (see ggml_backend_rpc_init) is not
    // removed here, so it dangles after the backend is freed
    delete buft_ctx;
    delete rpc_ctx->buft;
    delete rpc_ctx;
    delete backend;
}

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
    return ctx->buft;
}

GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
    // this is a no-op because the RPC backend does not have any async operations
}
static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
    if (tensor == nullptr) {
        return;
    }
    if (visited.find(tensor) != visited.end()) {
        return;
    }
    visited.insert(tensor);
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        add_tensor(tensor->src[i], tensors, visited);
    }
    add_tensor(tensor->view_src, tensors, visited);
    tensors.push_back(serialize_tensor(tensor));
}

static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
    uint32_t n_nodes = cgraph->n_nodes;
    std::vector<rpc_tensor> tensors;
    std::unordered_set<ggml_tensor*> visited;
    for (uint32_t i = 0; i < n_nodes; i++) {
        add_tensor(cgraph->nodes[i], tensors, visited);
    }
    // serialization format:
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    uint32_t n_tensors = tensors.size();
    size_t output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
    output.resize(output_size, 0);
    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
    uint64_t * out_nodes = (uint64_t *)(output.data() + sizeof(n_nodes));
    for (uint32_t i = 0; i < n_nodes; i++) {
        out_nodes[i] = reinterpret_cast<uint64_t>(cgraph->nodes[i]);
    }
    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
    *out_ntensors = n_tensors;
    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}
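
// Worked size example: for a graph whose single node is c = ggml_add(ctx, a, b),
// add_tensor collects {a, b, c} in post-order, so the serialized payload is
//
//   4 (n_nodes = 1) + 1*8 (node id) + 4 (n_tensors = 3) + 3*sizeof(rpc_tensor)
//
// bytes: the node list carries only client-side ids, while the metadata needed to
// rebuild the tensors travels in the trailing rpc_tensor array.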
GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    std::vector<uint8_t> input;
    serialize_graph(cgraph, input);
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 1);
    return (enum ggml_status)output[0];
}

GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
    UNUSED(backend);
    UNUSED(op);
    GGML_ASSERT(false && "not implemented");
    return false;
}

static ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name                = */ ggml_backend_rpc_name,
    /* .free                    = */ ggml_backend_rpc_free,
    /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ ggml_backend_rpc_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
    /* .supports_op             = */ ggml_backend_rpc_supports_op,
    /* .offload_op              = */ NULL,
    /* .event_new               = */ NULL,
    /* .event_free              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .event_synchronize       = */ NULL,
};
static std::unordered_map<std::string, ggml_backend_t> instances;

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
    std::string endpoint_str(endpoint);
    if (instances.find(endpoint_str) != instances.end()) {
        return instances[endpoint_str];
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            return nullptr;
        }
    }
#endif
    GGML_PRINT_DEBUG("Connecting to %s\n", endpoint);
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return nullptr;
    }
    auto sock = socket_connect(host.c_str(), port);
    if (sock == nullptr) {
        return nullptr;
    }
    size_t alignment = get_alignment(sock);
    size_t max_size = get_max_size(sock);
    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .sock      = */ sock,
        /* .name      = */ "RPC" + std::to_string(sock->fd),
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };
    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .context = */ buft_ctx
    };
    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC" + std::to_string(sock->fd),
        /* .sock     = */ sock,
        /* .buft     = */ buft
    };
    instances[endpoint] = new ggml_backend {
        /* .guid      = */ ggml_backend_rpc_guid(),
        /* .interface = */ ggml_backend_rpc_interface,
        /* .context   = */ ctx
    };
    return instances[endpoint];
}
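
// Minimal client usage sketch (the endpoint value is illustrative and assumes a
// server is already listening there):
//
//   ggml_backend_t backend = ggml_backend_rpc_init("127.0.0.1:50052");
//   if (backend != nullptr) {
//       // allocate tensors from ggml_backend_rpc_buffer_type("127.0.0.1:50052"),
//       // then run graphs with ggml_backend_graph_compute(backend, graph)
//   }
//
// Repeated calls with the same endpoint string return the cached instance.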
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}

static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
    // output serialization format: | free (8 bytes) | total (8 bytes) |
    uint64_t free_mem;
    memcpy(&free_mem, output.data(), sizeof(free_mem));
    uint64_t total_mem;
    memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
    *free = free_mem;
    *total = total_mem;
}

GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == nullptr) {
        *free = 0;
        *total = 0;
        return;
    }
    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
    get_device_memory(ctx->sock, free, total);
}
// RPC server-side implementation

static void rpc_alloc_buffer(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | size (8 bytes) |
    uint64_t size;
    memcpy(&size, input.data(), sizeof(size));
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
    uint64_t remote_ptr = reinterpret_cast<uint64_t>(buffer);
    uint64_t remote_size = buffer->size;
    GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
    output.resize(2*sizeof(uint64_t), 0);
    memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
    memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
}

static void rpc_get_alignment(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    GGML_PRINT_DEBUG("[%s] alignment: %zu\n", __func__, alignment);
    // output serialization format: | alignment (8 bytes) |
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &alignment, sizeof(alignment));
}

static void rpc_get_max_size(ggml_backend_t backend, std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    GGML_PRINT_DEBUG("[%s] max_size: %zu\n", __func__, max_size);
    // output serialization format: | max_size (8 bytes) |
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &max_size, sizeof(max_size));
}
static void rpc_buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | remote_ptr (8 bytes) |
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    void * base = ggml_backend_buffer_get_base(buffer);
    // output serialization format: | base_ptr (8 bytes) |
    uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &base_ptr, sizeof(base_ptr));
}

static void rpc_free_buffer(const std::vector<uint8_t> & input) {
    // input serialization format: | remote_ptr (8 bytes) |
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    ggml_backend_buffer_free(buffer);
}

static void rpc_buffer_clear(const std::vector<uint8_t> & input) {
    // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    uint8_t value;
    memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    ggml_backend_buffer_clear(buffer, value);
}

static void rpc_set_tensor(const std::vector<uint8_t> & input) {
    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
    ggml_backend_tensor_set(tensor, data, offset, size);
    ggml_free(ctx);
}
static void rpc_get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    uint64_t size;
    memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
    // output serialization format: | data (size bytes) |
    output.resize(size, 0);
    ggml_backend_tensor_get(tensor, output.data(), offset, size);
    ggml_free(ctx);
}
static void rpc_copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | rpc_tensor src | rpc_tensor dst |
    const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
    const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));
    struct ggml_init_params params {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
    ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
    bool result = ggml_backend_buffer_copy_tensor(src, dst);
    // output serialization format: | result (1 byte) |
    output.resize(1, 0);
    output[0] = result;
    ggml_free(ctx);
}
static struct ggml_tensor * create_node(uint64_t id,
                                        struct ggml_context * ctx,
                                        const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                        std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
    if (id == 0) {
        return nullptr;
    }
    if (tensor_map.find(id) != tensor_map.end()) {
        return tensor_map[id];
    }
    const rpc_tensor * tensor = tensor_ptrs.at(id);
    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
    tensor_map[id] = result;
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
    }
    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
    result->view_offs = tensor->view_offs;
    return result;
}
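
// NOTE: id == 0 is the serialized form of a null pointer, and tensor_map memoizes
// nodes already created, so shared sources and views are rebuilt as a DAG instead
// of being duplicated. tensor_ptrs.at(id) throws std::out_of_range if a node
// references an id that was not shipped in the tensors array.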
static void rpc_graph_compute(ggml_backend_t backend, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format:
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    uint32_t n_nodes;
    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
    uint32_t n_tensors;
    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
    GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
    graph->n_nodes = n_nodes;
    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
    for (uint32_t i = 0; i < n_tensors; i++) {
        tensor_ptrs[tensors[i].id] = &tensors[i];
    }
    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
    for (uint32_t i = 0; i < n_nodes; i++) {
        graph->nodes[i] = create_node(nodes[i], ctx, tensor_ptrs, tensor_map);
    }
    ggml_status status = ggml_backend_graph_compute(backend, graph);
    // output serialization format: | status (1 byte) |
    output.resize(1, 0);
    output[0] = status;
    ggml_free(ctx);
}
static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
    while (true) {
        uint8_t cmd;
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        std::vector<uint8_t> input;
        std::vector<uint8_t> output;
        uint64_t input_size;
        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
            break;
        }
        input.resize(input_size);
        if (!recv_data(sockfd, input.data(), input_size)) {
            break;
        }
        switch (cmd) {
            case ALLOC_BUFFER: {
                rpc_alloc_buffer(backend, input, output);
                break;
            }
            case GET_ALIGNMENT: {
                rpc_get_alignment(backend, output);
                break;
            }
            case GET_MAX_SIZE: {
                rpc_get_max_size(backend, output);
                break;
            }
            case BUFFER_GET_BASE: {
                rpc_buffer_get_base(input, output);
                break;
            }
            case FREE_BUFFER: {
                rpc_free_buffer(input);
                break;
            }
            case BUFFER_CLEAR: {
                rpc_buffer_clear(input);
                break;
            }
            case SET_TENSOR: {
                rpc_set_tensor(input);
                break;
            }
            case GET_TENSOR: {
                rpc_get_tensor(input, output);
                break;
            }
            case COPY_TENSOR: {
                rpc_copy_tensor(input, output);
                break;
            }
            case GRAPH_COMPUTE: {
                rpc_graph_compute(backend, input, output);
                break;
            }
            case GET_DEVICE_MEMORY: {
                // output serialization format: | free (8 bytes) | total (8 bytes) |
                output.resize(2*sizeof(uint64_t), 0);
                memcpy(output.data(), &free_mem, sizeof(free_mem));
                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
                break;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", cmd);
                return;
            }
        }
        uint64_t output_size = output.size();
        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
            break;
        }
        if (!send_data(sockfd, output.data(), output_size)) {
            break;
        }
    }
}
void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return;
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            fprintf(stderr, "WSAStartup failed: %d\n", res);
            return;
        }
    }
#endif
    auto server_socket = create_server_socket(host.c_str(), port);
    if (server_socket == nullptr) {
        fprintf(stderr, "Failed to create server socket\n");
        return;
    }
    while (true) {
        auto client_socket = socket_accept(server_socket->fd);
        if (client_socket == nullptr) {
            fprintf(stderr, "Failed to accept client connection\n");
            return;
        }
        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
        printf("Client connection closed\n");
    }
#ifdef _WIN32
    WSACleanup();
#endif
}
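
// Minimal server usage sketch (the backend choice and memory figures are
// illustrative; any ggml backend can be served):
//
//   ggml_backend_t backend = ggml_backend_cpu_init();
//   size_t free_mem  = 16ULL * 1024 * 1024 * 1024; // advertise 16 GiB free
//   size_t total_mem = 32ULL * 1024 * 1024 * 1024; // advertise 32 GiB total
//   start_rpc_server(backend, "0.0.0.0:50052", free_mem, total_mem);
//
// The server handles one client at a time and only returns on setup failure.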