// ggml-rpc.cpp
#include "ggml-rpc.h"
#include "ggml.h"
#include "ggml-backend-impl.h"

#include <cinttypes>
#include <string>
#include <vector>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#ifdef _WIN32
#  define WIN32_LEAN_AND_MEAN
#  ifndef NOMINMAX
#    define NOMINMAX
#  endif
#  include <windows.h>
#  include <winsock2.h>
#else
#  include <arpa/inet.h>
#  include <sys/socket.h>
#  include <sys/types.h>
#  include <netinet/in.h>
#  include <netinet/tcp.h>
#  include <netdb.h>
#  include <unistd.h>
#endif
#include <string.h>

#define UNUSED GGML_UNUSED

#define GGML_DEBUG 0
#if (GGML_DEBUG >= 1)
#  define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__)
#else
#  define GGML_PRINT_DEBUG(...)
#endif

#ifdef _WIN32
typedef SOCKET sockfd_t;
using ssize_t = __int64;
#else
typedef int sockfd_t;
#endif

// cross-platform socket
struct socket_t {
    sockfd_t fd;
    socket_t(sockfd_t fd) : fd(fd) {}
    ~socket_t() {
        GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32
        closesocket(this->fd);
#else
        close(this->fd);
#endif
    }
};

// ggml_tensor is serialized into rpc_tensor
#pragma pack(push, 1)
struct rpc_tensor {
    uint64_t id;
    uint32_t type;
    uint64_t buffer;
    uint32_t ne[GGML_MAX_DIMS];
    uint32_t nb[GGML_MAX_DIMS];
    uint32_t op;
    int32_t  op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
    int32_t  flags;
    uint64_t src[GGML_MAX_SRC];
    uint64_t view_src;
    uint64_t view_offs;
    uint64_t data;
    char     name[GGML_MAX_NAME];
    char     padding[4];
};
#pragma pack(pop)

static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");

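// Design note: the packed layout plus the explicit trailing padding keeps
// sizeof(rpc_tensor) a multiple of 8, so arrays of descriptors can be memcpy'd
// straight into request payloads without per-element alignment fixups. Note
// that ne/nb are narrowed to uint32_t on the wire (ggml itself uses
// int64_t/size_t), which assumes no dimension or stride exceeds UINT32_MAX.
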
// RPC commands
enum rpc_cmd {
    ALLOC_BUFFER = 0,
    GET_ALIGNMENT,
    GET_MAX_SIZE,
    BUFFER_GET_BASE,
    FREE_BUFFER,
    BUFFER_CLEAR,
    SET_TENSOR,
    GET_TENSOR,
    COPY_TENSOR,
    GRAPH_COMPUTE,
    GET_DEVICE_MEMORY,
};

// RPC data structures
static ggml_guid_t ggml_backend_rpc_guid() {
    static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
    return &guid;
}

struct ggml_backend_rpc_buffer_type_context {
    std::string endpoint;
    std::string name;
    size_t      alignment;
    size_t      max_size;
};

struct ggml_backend_rpc_context {
    std::string endpoint;
    std::string name;
};

struct ggml_backend_rpc_buffer_context {
    std::shared_ptr<socket_t> sock;
    std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
    uint64_t remote_ptr;
    std::string name;
};

// RPC helper functions
static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
#ifdef _WIN32
    if (fd == INVALID_SOCKET) {
        return nullptr;
    }
#else
    if (fd < 0) {
        return nullptr;
    }
#endif
    return std::make_shared<socket_t>(fd);
}

static bool set_no_delay(sockfd_t sockfd) {
    int flag = 1;
    // set TCP_NODELAY to disable Nagle's algorithm
    int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
    return ret == 0;
}

static bool set_reuse_addr(sockfd_t sockfd) {
    int flag = 1;
    int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
    return ret == 0;
}

static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
    struct sockaddr_in addr;
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock_ptr = make_socket(sockfd);
    if (sock_ptr == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(sockfd)) {
        fprintf(stderr, "Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    addr.sin_family = AF_INET;
    addr.sin_port = htons(port);
    struct hostent * server = gethostbyname(host);
    if (server == NULL) {
        fprintf(stderr, "Cannot resolve host '%s'\n", host);
        return nullptr;
    }
    memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
    if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
        return nullptr;
    }
    return sock_ptr;
}

static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
    auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
    auto client_socket = make_socket(client_socket_fd);
    if (client_socket == nullptr) {
        return nullptr;
    }
    if (!set_no_delay(client_socket_fd)) {
        fprintf(stderr, "Failed to set TCP_NODELAY\n");
        return nullptr;
    }
    return client_socket;
}

static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
    auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
    auto sock = make_socket(sockfd);
    if (sock == nullptr) {
        return nullptr;
    }
    if (!set_reuse_addr(sockfd)) {
        fprintf(stderr, "Failed to set SO_REUSEADDR\n");
        return nullptr;
    }
    struct sockaddr_in serv_addr;
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_addr.s_addr = inet_addr(host);
    serv_addr.sin_port = htons(port);
    if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
        return nullptr;
    }
    if (listen(sockfd, 1) < 0) {
        return nullptr;
    }
    return sock;
}

static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
    size_t bytes_sent = 0;
    while (bytes_sent < size) {
        ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
        if (n < 0) {
            return false;
        }
        bytes_sent += n;
    }
    return true;
}

static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
    size_t bytes_recv = 0;
    while (bytes_recv < size) {
        ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
        if (n <= 0) {
            return false;
        }
        bytes_recv += n;
    }
    return true;
}

static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
    size_t pos = endpoint.find(':');
    if (pos == std::string::npos) {
        return false;
    }
    host = endpoint.substr(0, pos);
    port = std::stoi(endpoint.substr(pos + 1));
    return true;
}

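// For example, parse_endpoint("localhost:50052", host, port) yields
// host == "localhost" and port == 50052; a string without ':' is rejected.
// Be aware that std::stoi throws on a non-numeric port and no caller here
// catches that exception.
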
// RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
// RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    uint8_t cmd_byte = cmd;
    if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
        return false;
    }
    uint64_t input_size = input.size();
    if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
        return false;
    }
    if (!send_data(sock->fd, input.data(), input.size())) {
        return false;
    }
    uint64_t output_size;
    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
        return false;
    }
    if (output_size == 0) {
        output.clear();
        return true;
    }
    output.resize(output_size);
    if (!recv_data(sock->fd, output.data(), output_size)) {
        return false;
    }
    return true;
}

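// Worked example of the framing above (integers are sent in host byte order,
// i.e. little-endian on typical targets): a GET_ALIGNMENT call carries no
// payload, so the client sends
//
//     | 01 | 00 00 00 00 00 00 00 00 |
//       cmd   request_size = 0 (uint64)
//
// and the server replies with an 8-byte size followed by the payload:
//
//     | 08 00 00 00 00 00 00 00 | xx xx xx xx xx xx xx xx |
//       response_size = 8         alignment (uint64)
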
// RPC client-side implementation

static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
    static bool initialized = false;

    auto it = sockets.find(endpoint);
    if (it != sockets.end()) {
        if (auto sock = it->second.lock()) {
            return sock;
        }
    }
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return nullptr;
    }
#ifdef _WIN32
    if (!initialized) {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            return nullptr;
        }
        initialized = true;
    }
#else
    UNUSED(initialized);
#endif
    auto sock = socket_connect(host.c_str(), port);
    if (sock == nullptr) {
        return nullptr;
    }
    GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
    sockets[endpoint] = sock;
    return sock;
}

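// Design note: sockets are cached per endpoint in a map of weak_ptr, so every
// buffer and backend talking to the same endpoint shares one connection while
// any owner is alive; once the last shared_ptr is dropped the socket closes
// and the next get_socket() call for that endpoint reconnects.
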
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    return ctx->name.c_str();
}

GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | remote_ptr (8 bytes) |
    std::vector<uint8_t> input(sizeof(uint64_t), 0);
    uint64_t remote_ptr = ctx->remote_ptr;
    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.empty());
    delete ctx;
}

GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
        return ctx->base_cache[buffer];
    }
    // input serialization format: | remote_ptr (8 bytes) |
    std::vector<uint8_t> input(sizeof(uint64_t), 0);
    uint64_t remote_ptr = ctx->remote_ptr;
    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | base_ptr (8 bytes) |
    uint64_t base_ptr;
    memcpy(&base_ptr, output.data(), sizeof(base_ptr));
    void * base = reinterpret_cast<void *>(base_ptr);
    ctx->base_cache[buffer] = base;
    return base;
}

static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
    rpc_tensor result;
    result.id = reinterpret_cast<uint64_t>(tensor);
    result.type = tensor->type;
    if (tensor->buffer) {
        ggml_backend_buffer_t buffer = tensor->buffer;
        ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
        result.buffer = ctx->remote_ptr;
    } else {
        result.buffer = 0;
    }
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result.ne[i] = tensor->ne[i];
        result.nb[i] = tensor->nb[i];
    }
    result.op = tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result.op_params[i] = tensor->op_params[i];
    }
    result.flags = tensor->flags;
    for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
        result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
    }
    result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
    result.view_offs = tensor->view_offs;
    result.data = reinterpret_cast<uint64_t>(tensor->data);
    snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
    return result;
}

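// Note that client-side ggml_tensor addresses double as opaque ids on the
// wire: the server never dereferences id/src/view_src, it only uses them as
// lookup keys when rebuilding a graph (see rpc_server::create_node below).
// The data field is different: it already holds a server-side address,
// because the buffer's base pointer was fetched via BUFFER_GET_BASE.
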
GGML_CALL static void ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    UNUSED(buffer);
    if (ggml_is_quantized(tensor->type)) {
        // TODO: this check is due to MATRIX_ROW_PADDING in CUDA and should be generalized
        GGML_ASSERT(tensor->ne[0] % 512 == 0 && "unsupported quantized tensor");
    }
}

GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor rpc_tensor = serialize_tensor(tensor);
    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
    GGML_ASSERT(status);
}

GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
    int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor rpc_tensor = serialize_tensor(tensor);
    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == size);
    // output serialization format: | data (size bytes) |
    memcpy(data, output.data(), size);
}

GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
    // check if src and dst are on the same server
    ggml_backend_buffer_t src_buffer = src->buffer;
    ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
    ggml_backend_buffer_t dst_buffer = dst->buffer;
    ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
    if (src_ctx->sock != dst_ctx->sock) {
        return false;
    }
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // input serialization format: | rpc_tensor src | rpc_tensor dst |
    int input_size = 2*sizeof(rpc_tensor);
    std::vector<uint8_t> input(input_size, 0);
    rpc_tensor rpc_src = serialize_tensor(src);
    rpc_tensor rpc_dst = serialize_tensor(dst);
    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
    GGML_ASSERT(status);
    // output serialization format: | result (1 byte) |
    GGML_ASSERT(output.size() == 1);
    return output[0];
}

GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
    ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
    // serialization format: | bufptr (8 bytes) | value (1 byte) |
    int input_size = sizeof(uint64_t) + sizeof(uint8_t);
    std::vector<uint8_t> input(input_size, 0);
    memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
    memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
    GGML_ASSERT(status);
}

static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
    /* .get_name    = */ ggml_backend_rpc_buffer_get_name,
    /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
    /* .get_base    = */ ggml_backend_rpc_buffer_get_base,
    /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
    /* .set_tensor  = */ ggml_backend_rpc_buffer_set_tensor,
    /* .get_tensor  = */ ggml_backend_rpc_buffer_get_tensor,
    /* .cpy_tensor  = */ ggml_backend_rpc_buffer_cpy_tensor,
    /* .clear       = */ ggml_backend_rpc_buffer_clear,
    /* .reset       = */ NULL,
};

GGML_CALL static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->name.c_str();
}

GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    // input serialization format: | size (8 bytes) |
    int input_size = sizeof(uint64_t);
    std::vector<uint8_t> input(input_size, 0);
    memcpy(input.data(), &size, sizeof(size));
    std::vector<uint8_t> output;
    auto sock = get_socket(buft_ctx->endpoint);
    bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
    uint64_t remote_ptr;
    memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
    size_t remote_size;
    memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
    if (remote_ptr != 0) {
        ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
            ggml_backend_rpc_buffer_interface,
            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
            remote_size);
        return buffer;
    } else {
        return nullptr;
    }
}

static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | alignment (8 bytes) |
    uint64_t alignment;
    memcpy(&alignment, output.data(), sizeof(alignment));
    return alignment;
}

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->alignment;
}

static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == sizeof(uint64_t));
    // output serialization format: | max_size (8 bytes) |
    uint64_t max_size;
    memcpy(&max_size, output.data(), sizeof(max_size));
    return max_size;
}

GGML_CALL static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    return buft_ctx->max_size;
}

GGML_CALL static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
    UNUSED(buft);
    return ggml_nbytes(tensor);
}

static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_rpc_buffer_type_name,
    /* .alloc_buffer   = */ ggml_backend_rpc_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_rpc_buffer_type_get_alignment,
    /* .get_max_size   = */ ggml_backend_rpc_get_max_size,
    /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
    /* .is_host        = */ NULL,
};

GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    return rpc_ctx->name.c_str();
}

GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    delete rpc_ctx;
    delete backend;
}

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
    ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
}

GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
    UNUSED(backend);
    // this is a no-op because we don't have any async operations
}

static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
    if (tensor == nullptr) {
        return;
    }
    if (visited.find(tensor) != visited.end()) {
        return;
    }
    visited.insert(tensor);
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        add_tensor(tensor->src[i], tensors, visited);
    }
    add_tensor(tensor->view_src, tensors, visited);
    tensors.push_back(serialize_tensor(tensor));
}

static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
    uint32_t n_nodes = cgraph->n_nodes;
    std::vector<rpc_tensor> tensors;
    std::unordered_set<ggml_tensor*> visited;
    for (uint32_t i = 0; i < n_nodes; i++) {
        add_tensor(cgraph->nodes[i], tensors, visited);
    }
    // serialization format:
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    uint32_t n_tensors = tensors.size();
    int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
    output.resize(output_size, 0);
    memcpy(output.data(), &n_nodes, sizeof(n_nodes));
    for (uint32_t i = 0; i < n_nodes; i++) {
        memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
    }
    uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
    *out_ntensors = n_tensors;
    rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
    memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
}

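// Worked example (ids illustrative): a graph with the single node C = add(A, B)
// serializes as
//
//     | n_nodes = 1 | nodes = { id(C) } | n_tensors = 3 | tensors = { A, B, C } |
//
// add_tensor() walks sources in post-order, so A and B precede C in the tensor
// table; the server does not rely on this order, since it indexes the table by
// id before rebuilding the nodes.
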
GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    std::vector<uint8_t> input;
    serialize_graph(cgraph, input);
    std::vector<uint8_t> output;
    auto sock = get_socket(rpc_ctx->endpoint);
    bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 1);
    return (enum ggml_status)output[0];
}

GGML_CALL static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
    UNUSED(backend);
    UNUSED(op);
    // TODO: call the remote backend and cache the results
    return true;
}

GGML_CALL static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    if (buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
        return false;
    }
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
    return buft_ctx->endpoint == rpc_ctx->endpoint;
}

static ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name                = */ ggml_backend_rpc_name,
    /* .free                    = */ ggml_backend_rpc_free,
    /* .get_default_buffer_type = */ ggml_backend_rpc_get_default_buffer_type,
    /* .set_tensor_async        = */ NULL,
    /* .get_tensor_async        = */ NULL,
    /* .cpy_tensor_async        = */ NULL,
    /* .synchronize             = */ ggml_backend_rpc_synchronize,
    /* .graph_plan_create       = */ NULL,
    /* .graph_plan_free         = */ NULL,
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
    /* .supports_op             = */ ggml_backend_rpc_supports_op,
    /* .supports_buft           = */ ggml_backend_rpc_supports_buft,
    /* .offload_op              = */ NULL,
    /* .event_new               = */ NULL,
    /* .event_free              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
    /* .event_synchronize       = */ NULL,
};

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);
    // NOTE: buffer types are allocated and never freed; this is by design
    static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
    auto it = buft_map.find(endpoint);
    if (it != buft_map.end()) {
        return it->second;
    }
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        return nullptr;
    }
    size_t alignment = get_alignment(sock);
    size_t max_size = get_max_size(sock);
    ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
        /* .endpoint  = */ endpoint,
        /* .name      = */ "RPC[" + std::string(endpoint) + "]",
        /* .alignment = */ alignment,
        /* .max_size  = */ max_size
    };
    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
        /* .context = */ buft_ctx
    };
    buft_map[endpoint] = buft;
    return buft;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
    ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
    };
    ggml_backend_t backend = new ggml_backend {
        /* .guid      = */ ggml_backend_rpc_guid(),
        /* .interface = */ ggml_backend_rpc_interface,
        /* .context   = */ ctx
    };
    return backend;
}

GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
}

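// Example (illustrative sketch, not part of this file): minimal client-side
// bring-up; the endpoint value is an assumption. Note that
// ggml_backend_rpc_init() only allocates the context - the connection is made
// lazily on first use, so querying device memory is a cheap way to probe it.
//
//     ggml_backend_t backend = ggml_backend_rpc_init("127.0.0.1:50052");
//     size_t free_mem = 0, total_mem = 0;
//     ggml_backend_rpc_get_device_memory("127.0.0.1:50052", &free_mem, &total_mem);
//     // ... build a graph and call ggml_backend_graph_compute(backend, graph) ...
//     ggml_backend_free(backend);
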
static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
    // input serialization format: | 0 bytes |
    std::vector<uint8_t> input;
    std::vector<uint8_t> output;
    bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
    GGML_ASSERT(status);
    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
    // output serialization format: | free (8 bytes) | total (8 bytes) |
    uint64_t free_mem;
    memcpy(&free_mem, output.data(), sizeof(free_mem));
    uint64_t total_mem;
    memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
    *free = free_mem;
    *total = total_mem;
}

GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
    auto sock = get_socket(endpoint);
    if (sock == nullptr) {
        *free = 0;
        *total = 0;
        return;
    }
    get_device_memory(sock, free, total);
}

// RPC server-side implementation

class rpc_server {
public:
    rpc_server(ggml_backend_t backend) : backend(backend) {}
    ~rpc_server();

    bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
    void get_alignment(std::vector<uint8_t> & output);
    void get_max_size(std::vector<uint8_t> & output);
    bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
    bool free_buffer(const std::vector<uint8_t> & input);
    bool buffer_clear(const std::vector<uint8_t> & input);
    bool set_tensor(const std::vector<uint8_t> & input);
    bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
    bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
    bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);

private:
    ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
    ggml_tensor * create_node(uint64_t id,
                              struct ggml_context * ctx,
                              const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                              std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);

    ggml_backend_t backend;
    std::unordered_set<ggml_backend_buffer_t> buffers;
};

bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | size (8 bytes) |
    if (input.size() != sizeof(uint64_t)) {
        return false;
    }
    uint64_t size;
    memcpy(&size, input.data(), sizeof(size));
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
    uint64_t remote_ptr = 0;
    uint64_t remote_size = 0;
    if (buffer != nullptr) {
        remote_ptr = reinterpret_cast<uint64_t>(buffer);
        remote_size = buffer->size;
        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
        buffers.insert(buffer);
    } else {
        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
    }
    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
    output.resize(2*sizeof(uint64_t), 0);
    memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
    memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
    return true;
}

void rpc_server::get_alignment(std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t alignment = ggml_backend_buft_get_alignment(buft);
    GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
    // output serialization format: | alignment (8 bytes) |
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &alignment, sizeof(alignment));
}

void rpc_server::get_max_size(std::vector<uint8_t> & output) {
    ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
    size_t max_size = ggml_backend_buft_get_max_size(buft);
    GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
    // output serialization format: | max_size (8 bytes) |
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &max_size, sizeof(max_size));
}

bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // input serialization format: | remote_ptr (8 bytes) |
    if (input.size() != sizeof(uint64_t)) {
        return false;
    }
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
        return false;
    }
    void * base = ggml_backend_buffer_get_base(buffer);
    // output serialization format: | base_ptr (8 bytes) |
    uint64_t base_ptr = reinterpret_cast<uint64_t>(base);
    output.resize(sizeof(uint64_t), 0);
    memcpy(output.data(), &base_ptr, sizeof(base_ptr));
    return true;
}

bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
    // input serialization format: | remote_ptr (8 bytes) |
    if (input.size() != sizeof(uint64_t)) {
        return false;
    }
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
        return false;
    }
    ggml_backend_buffer_free(buffer);
    buffers.erase(buffer);
    return true;
}

bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
    // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
    if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
        return false;
    }
    uint64_t remote_ptr;
    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
    uint8_t value;
    memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
    if (buffers.find(buffer) == buffers.end()) {
        GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
        return false;
    }
    ggml_backend_buffer_clear(buffer, value);
    return true;
}

ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
    ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
        tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
    for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
        result->nb[i] = tensor->nb[i];
    }
    result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
    if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
        // reject tensors that reference a buffer this server did not allocate
        return nullptr;
    }
    result->op = (ggml_op) tensor->op;
    for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
        result->op_params[i] = tensor->op_params[i];
    }
    result->flags = tensor->flags;
    result->data = reinterpret_cast<void *>(tensor->data);
    ggml_set_name(result, tensor->name);
    return result;
}

bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
    // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
    if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
        return false;
    }
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);

    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    if (tensor == nullptr) {
        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
        ggml_free(ctx);
        return false;
    }
    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
    const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
    ggml_backend_tensor_set(tensor, data, offset, size);
    ggml_free(ctx);
    return true;
}

bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
    if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
        return false;
    }
    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
    uint64_t offset;
    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
    uint64_t size;
    memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));

    struct ggml_init_params params {
        /*.mem_size   =*/ ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
    if (tensor == nullptr) {
        GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
        ggml_free(ctx);
        return false;
    }
    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
    // output serialization format: | data (size bytes) |
    output.resize(size, 0);
    ggml_backend_tensor_get(tensor, output.data(), offset, size);
    ggml_free(ctx);
    return true;
}

bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // serialization format: | rpc_tensor src | rpc_tensor dst |
    if (input.size() != 2*sizeof(rpc_tensor)) {
        return false;
    }
    const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
    // the dst descriptor starts one full rpc_tensor past the src descriptor;
    // this must be sizeof(rpc_tensor), not sizeof(rpc_src) - the latter is a
    // pointer here and would yield the wrong offset
    const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_tensor));

    struct ggml_init_params params {
        /*.mem_size   =*/ 2*ggml_tensor_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
    ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
    if (src == nullptr || dst == nullptr) {
        GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
        ggml_free(ctx);
        return false;
    }
    GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
    bool result = ggml_backend_buffer_copy_tensor(src, dst);
    // output serialization format: | result (1 byte) |
    output.resize(1, 0);
    output[0] = result;
    ggml_free(ctx);
    return true;
}

ggml_tensor * rpc_server::create_node(uint64_t id,
                                      struct ggml_context * ctx,
                                      const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
                                      std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
    if (id == 0) {
        return nullptr;
    }
    if (tensor_map.find(id) != tensor_map.end()) {
        return tensor_map[id];
    }
    const rpc_tensor * tensor = tensor_ptrs.at(id);
    struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
    if (result == nullptr) {
        return nullptr;
    }
    tensor_map[id] = result;
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
    }
    result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
    result->view_offs = tensor->view_offs;
    return result;
}

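// create_node() memoizes results in tensor_map, so a tensor shared by several
// nodes is deserialized exactly once and the client-side DAG shape is
// preserved; an id of 0 encodes a null src/view_src pointer.
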
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
    // serialization format:
    // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    if (input.size() < sizeof(uint32_t)) {
        return false;
    }
    uint32_t n_nodes;
    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
        return false;
    }
    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
    uint32_t n_tensors;
    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
        return false;
    }
    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
    GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);
    // note: buf_size must be recomputed per request - making it static would
    // freeze the allocation at the size of the first graph ever received
    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);
    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
    graph->n_nodes = n_nodes;
    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
    for (uint32_t i = 0; i < n_tensors; i++) {
        tensor_ptrs[tensors[i].id] = &tensors[i];
    }
    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
    for (uint32_t i = 0; i < n_nodes; i++) {
        uint64_t id;
        memcpy(&id, &nodes[i], sizeof(id));
        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
    }
    ggml_status status = ggml_backend_graph_compute(backend, graph);
    // output serialization format: | status (1 byte) |
    output.resize(1, 0);
    output[0] = status;
    ggml_free(ctx);
    return true;
}

rpc_server::~rpc_server() {
    for (auto buffer : buffers) {
        ggml_backend_buffer_free(buffer);
    }
}

static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t free_mem, size_t total_mem) {
    rpc_server server(backend);
    while (true) {
        uint8_t cmd;
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        std::vector<uint8_t> input;
        std::vector<uint8_t> output;
        uint64_t input_size;
        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
            break;
        }
        input.resize(input_size);
        if (!recv_data(sockfd, input.data(), input_size)) {
            break;
        }
        bool ok = true;
        switch (cmd) {
            case ALLOC_BUFFER: {
                ok = server.alloc_buffer(input, output);
                break;
            }
            case GET_ALIGNMENT: {
                server.get_alignment(output);
                break;
            }
            case GET_MAX_SIZE: {
                server.get_max_size(output);
                break;
            }
            case BUFFER_GET_BASE: {
                ok = server.buffer_get_base(input, output);
                break;
            }
            case FREE_BUFFER: {
                ok = server.free_buffer(input);
                break;
            }
            case BUFFER_CLEAR: {
                ok = server.buffer_clear(input);
                break;
            }
            case SET_TENSOR: {
                ok = server.set_tensor(input);
                break;
            }
            case GET_TENSOR: {
                ok = server.get_tensor(input, output);
                break;
            }
            case COPY_TENSOR: {
                ok = server.copy_tensor(input, output);
                break;
            }
            case GRAPH_COMPUTE: {
                ok = server.graph_compute(input, output);
                break;
            }
            case GET_DEVICE_MEMORY: {
                // output serialization format: | free (8 bytes) | total (8 bytes) |
                output.resize(2*sizeof(uint64_t), 0);
                memcpy(output.data(), &free_mem, sizeof(free_mem));
                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
                break;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", cmd);
                ok = false;
            }
        }
        if (!ok) {
            break;
        }
        uint64_t output_size = output.size();
        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
            break;
        }
        if (!send_data(sockfd, output.data(), output_size)) {
            break;
        }
    }
}

void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return;
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            fprintf(stderr, "WSAStartup failed: %d\n", res);
            return;
        }
    }
#endif
    auto server_socket = create_server_socket(host.c_str(), port);
    if (server_socket == nullptr) {
        fprintf(stderr, "Failed to create server socket\n");
        return;
    }
    while (true) {
        auto client_socket = socket_accept(server_socket->fd);
        if (client_socket == nullptr) {
            fprintf(stderr, "Failed to accept client connection\n");
            return;
        }
        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
        rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
        printf("Client connection closed\n");
    }
#ifdef _WIN32
    WSACleanup();
#endif
}

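// Example (illustrative sketch, not part of this file): a minimal server
// main() built on start_rpc_server(), using the CPU backend; the endpoint and
// memory figures are assumptions - a real server would advertise the actual
// device memory of whatever backend it wraps.
//
//     int main() {
//         ggml_backend_t backend = ggml_backend_cpu_init();
//         size_t mem = 8ull * 1024 * 1024 * 1024; // placeholder: 8 GiB
//         start_rpc_server(backend, "0.0.0.0:50052", /*free_mem*/ mem, /*total_mem*/ mem);
//         ggml_backend_free(backend);
//         return 0;
//     }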