ggml-rpc.cpp 65 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511
6611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813
  1. #include "ggml-rpc.h"
  2. #include "ggml-impl.h"
  3. #include "ggml-backend-impl.h"
  4. #include "ggml-cpp.h"
  5. #include <cinttypes>
  6. #include <string>
  7. #include <vector>
  8. #include <memory>
  9. #include <mutex>
  10. #include <unordered_map>
  11. #include <unordered_set>
  12. #ifdef _WIN32
  13. # define WIN32_LEAN_AND_MEAN
  14. # ifndef NOMINMAX
  15. # define NOMINMAX
  16. # endif
  17. # include <windows.h>
  18. # include <winsock2.h>
  19. #else
  20. # include <arpa/inet.h>
  21. # include <sys/socket.h>
  22. # include <sys/types.h>
  23. # include <netinet/in.h>
  24. # include <netinet/tcp.h>
  25. # include <netdb.h>
  26. # include <unistd.h>
  27. #endif
  28. #include <cstring>
  29. #include <fstream>
  30. #include <filesystem>
  31. namespace fs = std::filesystem;
  32. #ifdef _WIN32
  33. typedef SOCKET sockfd_t;
  34. using ssize_t = __int64;
  35. #else
  36. typedef int sockfd_t;
  37. #endif
  38. // cross-platform socket
  39. struct socket_t {
  40. sockfd_t fd;
  41. socket_t(sockfd_t fd) : fd(fd) {}
  42. ~socket_t() {
  43. GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
  44. #ifdef _WIN32
  45. closesocket(this->fd);
  46. #else
  47. close(this->fd);
  48. #endif
  49. }
  50. };
  51. // all RPC structures must be packed
  52. #pragma pack(push, 1)
  53. // ggml_tensor is serialized into rpc_tensor
  54. struct rpc_tensor {
  55. uint64_t id;
  56. uint32_t type;
  57. uint64_t buffer;
  58. uint32_t ne[GGML_MAX_DIMS];
  59. uint32_t nb[GGML_MAX_DIMS];
  60. uint32_t op;
  61. int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
  62. int32_t flags;
  63. uint64_t src[GGML_MAX_SRC];
  64. uint64_t view_src;
  65. uint64_t view_offs;
  66. uint64_t data;
  67. char name[GGML_MAX_NAME];
  68. char padding[4];
  69. };
  70. static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
  71. // RPC commands
  72. enum rpc_cmd {
  73. RPC_CMD_ALLOC_BUFFER = 0,
  74. RPC_CMD_GET_ALIGNMENT,
  75. RPC_CMD_GET_MAX_SIZE,
  76. RPC_CMD_BUFFER_GET_BASE,
  77. RPC_CMD_FREE_BUFFER,
  78. RPC_CMD_BUFFER_CLEAR,
  79. RPC_CMD_SET_TENSOR,
  80. RPC_CMD_SET_TENSOR_HASH,
  81. RPC_CMD_GET_TENSOR,
  82. RPC_CMD_COPY_TENSOR,
  83. RPC_CMD_GRAPH_COMPUTE,
  84. RPC_CMD_GET_DEVICE_MEMORY,
  85. RPC_CMD_INIT_TENSOR,
  86. RPC_CMD_GET_ALLOC_SIZE,
  87. RPC_CMD_HELLO,
  88. RPC_CMD_COUNT,
  89. };
  90. // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
  91. const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
  92. struct rpc_msg_hello_rsp {
  93. uint8_t major;
  94. uint8_t minor;
  95. uint8_t patch;
  96. };
  97. struct rpc_msg_get_alloc_size_req {
  98. rpc_tensor tensor;
  99. };
  100. struct rpc_msg_get_alloc_size_rsp {
  101. uint64_t alloc_size;
  102. };
  103. struct rpc_msg_init_tensor_req {
  104. rpc_tensor tensor;
  105. };
  106. struct rpc_msg_alloc_buffer_req {
  107. uint64_t size;
  108. };
  109. struct rpc_msg_alloc_buffer_rsp {
  110. uint64_t remote_ptr;
  111. uint64_t remote_size;
  112. };
  113. struct rpc_msg_get_alignment_rsp {
  114. uint64_t alignment;
  115. };
  116. struct rpc_msg_get_max_size_rsp {
  117. uint64_t max_size;
  118. };
  119. struct rpc_msg_buffer_get_base_req {
  120. uint64_t remote_ptr;
  121. };
  122. struct rpc_msg_buffer_get_base_rsp {
  123. uint64_t base_ptr;
  124. };
  125. struct rpc_msg_free_buffer_req {
  126. uint64_t remote_ptr;
  127. };
  128. struct rpc_msg_buffer_clear_req {
  129. uint64_t remote_ptr;
  130. uint8_t value;
  131. };
  132. struct rpc_msg_set_tensor_hash_req {
  133. rpc_tensor tensor;
  134. uint64_t offset;
  135. uint64_t hash;
  136. };
  137. struct rpc_msg_set_tensor_hash_rsp {
  138. uint8_t result;
  139. };
  140. struct rpc_msg_get_tensor_req {
  141. rpc_tensor tensor;
  142. uint64_t offset;
  143. uint64_t size;
  144. };
  145. struct rpc_msg_copy_tensor_req {
  146. rpc_tensor src;
  147. rpc_tensor dst;
  148. };
  149. struct rpc_msg_copy_tensor_rsp {
  150. uint8_t result;
  151. };
  152. struct rpc_msg_graph_compute_rsp {
  153. uint8_t result;
  154. };
  155. struct rpc_msg_get_device_memory_rsp {
  156. uint64_t free_mem;
  157. uint64_t total_mem;
  158. };
  159. #pragma pack(pop)
  160. // RPC data structures
  161. static ggml_guid_t ggml_backend_rpc_guid() {
  162. static ggml_guid guid = {0x99, 0x68, 0x5b, 0x6c, 0xd2, 0x83, 0x3d, 0x24, 0x25, 0x36, 0x72, 0xe1, 0x5b, 0x0e, 0x14, 0x03};
  163. return &guid;
  164. }
  165. struct ggml_backend_rpc_buffer_type_context {
  166. std::string endpoint;
  167. std::string name;
  168. size_t alignment;
  169. size_t max_size;
  170. };
  171. struct ggml_backend_rpc_context {
  172. std::string endpoint;
  173. std::string name;
  174. };
  175. struct ggml_backend_rpc_buffer_context {
  176. std::shared_ptr<socket_t> sock;
  177. void * base_ptr;
  178. uint64_t remote_ptr;
  179. };
  180. // RPC helper functions
  181. // Computes FNV-1a hash of the data
  182. static uint64_t fnv_hash(const uint8_t * data, size_t len) {
  183. const uint64_t fnv_prime = 0x100000001b3ULL;
  184. uint64_t hash = 0xcbf29ce484222325ULL;
  185. for (size_t i = 0; i < len; ++i) {
  186. hash ^= data[i];
  187. hash *= fnv_prime;
  188. }
  189. return hash;
  190. }
  191. static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
  192. #ifdef _WIN32
  193. if (fd == INVALID_SOCKET) {
  194. return nullptr;
  195. }
  196. #else
  197. if (fd < 0) {
  198. return nullptr;
  199. }
  200. #endif
  201. return std::make_shared<socket_t>(fd);
  202. }
  203. static bool set_no_delay(sockfd_t sockfd) {
  204. int flag = 1;
  205. // set TCP_NODELAY to disable Nagle's algorithm
  206. int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int));
  207. return ret == 0;
  208. }
  209. static bool set_reuse_addr(sockfd_t sockfd) {
  210. int flag = 1;
  211. int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int));
  212. return ret == 0;
  213. }
  214. static std::shared_ptr<socket_t> socket_connect(const char * host, int port) {
  215. struct sockaddr_in addr;
  216. auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
  217. auto sock_ptr = make_socket(sockfd);
  218. if (sock_ptr == nullptr) {
  219. return nullptr;
  220. }
  221. if (!set_no_delay(sockfd)) {
  222. fprintf(stderr, "Failed to set TCP_NODELAY\n");
  223. return nullptr;
  224. }
  225. addr.sin_family = AF_INET;
  226. addr.sin_port = htons(port);
  227. struct hostent * server = gethostbyname(host);
  228. if (server == NULL) {
  229. fprintf(stderr, "Cannot resolve host '%s'\n", host);
  230. return nullptr;
  231. }
  232. memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length);
  233. if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) {
  234. return nullptr;
  235. }
  236. return sock_ptr;
  237. }
  238. static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) {
  239. auto client_socket_fd = accept(srv_sockfd, NULL, NULL);
  240. auto client_socket = make_socket(client_socket_fd);
  241. if (client_socket == nullptr) {
  242. return nullptr;
  243. }
  244. if (!set_no_delay(client_socket_fd)) {
  245. fprintf(stderr, "Failed to set TCP_NODELAY\n");
  246. return nullptr;
  247. }
  248. return client_socket;
  249. }
  250. static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) {
  251. auto sockfd = socket(AF_INET, SOCK_STREAM, 0);
  252. auto sock = make_socket(sockfd);
  253. if (sock == nullptr) {
  254. return nullptr;
  255. }
  256. if (!set_reuse_addr(sockfd)) {
  257. fprintf(stderr, "Failed to set SO_REUSEADDR\n");
  258. return nullptr;
  259. }
  260. if (inet_addr(host) == INADDR_NONE) {
  261. fprintf(stderr, "Invalid host address: %s\n", host);
  262. return nullptr;
  263. }
  264. struct sockaddr_in serv_addr;
  265. serv_addr.sin_family = AF_INET;
  266. serv_addr.sin_addr.s_addr = inet_addr(host);
  267. serv_addr.sin_port = htons(port);
  268. if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) {
  269. return nullptr;
  270. }
  271. if (listen(sockfd, 1) < 0) {
  272. return nullptr;
  273. }
  274. return sock;
  275. }
  276. static bool send_data(sockfd_t sockfd, const void * data, size_t size) {
  277. size_t bytes_sent = 0;
  278. while (bytes_sent < size) {
  279. ssize_t n = send(sockfd, (const char *)data + bytes_sent, size - bytes_sent, 0);
  280. if (n < 0) {
  281. return false;
  282. }
  283. bytes_sent += n;
  284. }
  285. return true;
  286. }
  287. static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
  288. size_t bytes_recv = 0;
  289. while (bytes_recv < size) {
  290. ssize_t n = recv(sockfd, (char *)data + bytes_recv, size - bytes_recv, 0);
  291. if (n <= 0) {
  292. return false;
  293. }
  294. bytes_recv += n;
  295. }
  296. return true;
  297. }
  298. static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
  299. if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
  300. return false;
  301. }
  302. return send_data(sockfd, msg, msg_size);
  303. }
  304. static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
  305. uint64_t size;
  306. if (!recv_data(sockfd, &size, sizeof(size))) {
  307. return false;
  308. }
  309. if (size != msg_size) {
  310. return false;
  311. }
  312. return recv_data(sockfd, msg, msg_size);
  313. }
  314. static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
  315. uint64_t size;
  316. if (!recv_data(sockfd, &size, sizeof(size))) {
  317. return false;
  318. }
  319. try {
  320. input.resize(size);
  321. } catch (const std::bad_alloc & e) {
  322. fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
  323. return false;
  324. }
  325. return recv_data(sockfd, input.data(), size);
  326. }
  327. static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
  328. size_t pos = endpoint.find(':');
  329. if (pos == std::string::npos) {
  330. return false;
  331. }
  332. host = endpoint.substr(0, pos);
  333. port = std::stoi(endpoint.substr(pos + 1));
  334. return true;
  335. }
  336. // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
  337. // No response
  338. static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) {
  339. uint8_t cmd_byte = cmd;
  340. if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
  341. return false;
  342. }
  343. if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
  344. return false;
  345. }
  346. if (!send_data(sock->fd, input, input_size)) {
  347. return false;
  348. }
  349. return true;
  350. }
  351. // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
  352. // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
  353. static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
  354. if (!send_rpc_cmd(sock, cmd, input, input_size)) {
  355. return false;
  356. }
  357. // TODO: currently the output_size is always known, do we need support for commands with variable output size?
  358. // even if we do, we can skip sending output_size from the server for commands with known output size
  359. uint64_t out_size;
  360. if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
  361. return false;
  362. }
  363. if (out_size != output_size) {
  364. return false;
  365. }
  366. if (!recv_data(sock->fd, output, output_size)) {
  367. return false;
  368. }
  369. return true;
  370. }
  371. // RPC client-side implementation
  372. static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
  373. rpc_msg_hello_rsp response;
  374. bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
  375. GGML_ASSERT(status);
  376. if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
  377. fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
  378. return false;
  379. }
  380. if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
  381. fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
  382. }
  383. return true;
  384. }
  385. static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
  386. static std::mutex mutex;
  387. std::lock_guard<std::mutex> lock(mutex);
  388. static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
  389. static bool initialized = false;
  390. auto it = sockets.find(endpoint);
  391. if (it != sockets.end()) {
  392. if (auto sock = it->second.lock()) {
  393. return sock;
  394. }
  395. }
  396. std::string host;
  397. int port;
  398. if (!parse_endpoint(endpoint, host, port)) {
  399. return nullptr;
  400. }
  401. #ifdef _WIN32
  402. if (!initialized) {
  403. WSADATA wsaData;
  404. int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
  405. if (res != 0) {
  406. return nullptr;
  407. }
  408. initialized = true;
  409. }
  410. #else
  411. GGML_UNUSED(initialized);
  412. #endif
  413. auto sock = socket_connect(host.c_str(), port);
  414. if (sock == nullptr) {
  415. return nullptr;
  416. }
  417. if (!check_server_version(sock)) {
  418. return nullptr;
  419. }
  420. GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
  421. sockets[endpoint] = sock;
  422. return sock;
  423. }
  424. static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
  425. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  426. rpc_msg_free_buffer_req request = {ctx->remote_ptr};
  427. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
  428. GGML_ASSERT(status);
  429. delete ctx;
  430. }
  431. static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
  432. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  433. if (ctx->base_ptr != nullptr) {
  434. return ctx->base_ptr;
  435. }
  436. rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
  437. rpc_msg_buffer_get_base_rsp response;
  438. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
  439. GGML_ASSERT(status);
  440. ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
  441. return ctx->base_ptr;
  442. }
  443. static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
  444. rpc_tensor result;
  445. result.id = reinterpret_cast<uint64_t>(tensor);
  446. result.type = tensor->type;
  447. if (tensor->buffer) {
  448. ggml_backend_buffer_t buffer = tensor->buffer;
  449. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  450. result.buffer = ctx->remote_ptr;
  451. } else {
  452. result.buffer = 0;
  453. }
  454. for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
  455. result.ne[i] = tensor->ne[i];
  456. result.nb[i] = tensor->nb[i];
  457. }
  458. result.op = tensor->op;
  459. for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
  460. result.op_params[i] = tensor->op_params[i];
  461. }
  462. result.flags = tensor->flags;
  463. for (uint32_t i = 0; i < GGML_MAX_SRC; i++) {
  464. result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]);
  465. }
  466. result.view_src = reinterpret_cast<uint64_t>(tensor->view_src);
  467. result.view_offs = tensor->view_offs;
  468. result.data = reinterpret_cast<uint64_t>(tensor->data);
  469. // Avoid sending uninitialized data over the wire
  470. memset(result.name, 0, sizeof(result.name));
  471. memset(result.padding, 0, sizeof(result.padding));
  472. snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name);
  473. return result;
  474. }
  475. static enum ggml_status ggml_backend_rpc_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
  476. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  477. // CUDA backend on the server pads everything to 512 due to CUDA limitations.
  478. // Due to bandwidth constraints, we only call the server init tensor functions if necessary.
  479. // In particular, only quantized tensors need padding
  480. if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
  481. rpc_msg_init_tensor_req request;
  482. request.tensor = serialize_tensor(tensor);
  483. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_INIT_TENSOR, &request, sizeof(request), nullptr, 0);
  484. GGML_ASSERT(status);
  485. }
  486. return GGML_STATUS_SUCCESS;
  487. }
  488. static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
  489. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  490. rpc_tensor rpc_tensor = serialize_tensor(tensor);
  491. if (size > HASH_THRESHOLD) {
  492. rpc_msg_set_tensor_hash_req request;
  493. request.tensor = rpc_tensor;
  494. request.offset = offset;
  495. request.hash = fnv_hash((const uint8_t*)data, size);
  496. rpc_msg_set_tensor_hash_rsp response;
  497. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR_HASH, &request, sizeof(request), &response, sizeof(response));
  498. GGML_ASSERT(status);
  499. if (response.result) {
  500. // the server has the same data, no need to send it
  501. return;
  502. }
  503. }
  504. // input serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes)
  505. size_t input_size = sizeof(rpc_tensor) + sizeof(uint64_t) + size;
  506. std::vector<uint8_t> input(input_size, 0);
  507. memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
  508. memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
  509. memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
  510. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size());
  511. GGML_ASSERT(status);
  512. }
  513. static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
  514. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  515. rpc_msg_get_tensor_req request;
  516. request.tensor = serialize_tensor(tensor);
  517. request.offset = offset;
  518. request.size = size;
  519. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
  520. GGML_ASSERT(status);
  521. }
  522. static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
  523. // check if src and dst are on the same server
  524. ggml_backend_buffer_t src_buffer = src->buffer;
  525. ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
  526. ggml_backend_buffer_t dst_buffer = dst->buffer;
  527. ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
  528. if (src_ctx->sock != dst_ctx->sock) {
  529. return false;
  530. }
  531. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  532. rpc_msg_copy_tensor_req request;
  533. request.src = serialize_tensor(src);
  534. request.dst = serialize_tensor(dst);
  535. rpc_msg_copy_tensor_rsp response;
  536. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
  537. GGML_ASSERT(status);
  538. return response.result;
  539. }
  540. static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
  541. ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
  542. rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
  543. bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
  544. GGML_ASSERT(status);
  545. }
  546. static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = {
  547. /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer,
  548. /* .get_base = */ ggml_backend_rpc_buffer_get_base,
  549. /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor,
  550. /* .memset_tensor = */ NULL,
  551. /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor,
  552. /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor,
  553. /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor,
  554. /* .clear = */ ggml_backend_rpc_buffer_clear,
  555. /* .reset = */ NULL,
  556. };
  557. static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t buft) {
  558. ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  559. return buft_ctx->name.c_str();
  560. }
  561. static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
  562. ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  563. rpc_msg_alloc_buffer_req request = {size};
  564. rpc_msg_alloc_buffer_rsp response;
  565. auto sock = get_socket(buft_ctx->endpoint);
  566. bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
  567. GGML_ASSERT(status);
  568. if (response.remote_ptr != 0) {
  569. ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
  570. ggml_backend_rpc_buffer_interface,
  571. new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
  572. response.remote_size);
  573. return buffer;
  574. } else {
  575. return nullptr;
  576. }
  577. }
  578. static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
  579. rpc_msg_get_alignment_rsp response;
  580. bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
  581. GGML_ASSERT(status);
  582. return response.alignment;
  583. }
  584. static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
  585. ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  586. return buft_ctx->alignment;
  587. }
  588. static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
  589. rpc_msg_get_max_size_rsp response;
  590. bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
  591. GGML_ASSERT(status);
  592. return response.max_size;
  593. }
  594. static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
  595. ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  596. return buft_ctx->max_size;
  597. }
  598. static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
  599. // See comments in init_tensor.
  600. if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
  601. ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
  602. auto sock = get_socket(buft_ctx->endpoint);
  603. rpc_msg_get_alloc_size_req request;
  604. request.tensor = serialize_tensor(tensor);
  605. rpc_msg_get_alloc_size_rsp response;
  606. bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
  607. GGML_ASSERT(status);
  608. return response.alloc_size;
  609. } else {
  610. return ggml_nbytes(tensor);
  611. }
  612. }
  613. static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
  614. /* .get_name = */ ggml_backend_rpc_buffer_type_name,
  615. /* .alloc_buffer = */ ggml_backend_rpc_buffer_type_alloc_buffer,
  616. /* .get_alignment = */ ggml_backend_rpc_buffer_type_get_alignment,
  617. /* .get_max_size = */ ggml_backend_rpc_get_max_size,
  618. /* .get_alloc_size = */ ggml_backend_rpc_buffer_type_get_alloc_size,
  619. /* .is_host = */ NULL,
  620. };
  621. static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
  622. ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
  623. return rpc_ctx->name.c_str();
  624. }
  625. static void ggml_backend_rpc_free(ggml_backend_t backend) {
  626. ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
  627. delete rpc_ctx;
  628. delete backend;
  629. }
  630. static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
  631. GGML_UNUSED(backend);
  632. // this is no-op because we don't have any async operations
  633. }
  634. static void add_tensor(ggml_tensor * tensor, std::vector<rpc_tensor> & tensors, std::unordered_set<ggml_tensor*> & visited) {
  635. if (tensor == nullptr) {
  636. return;
  637. }
  638. if (visited.find(tensor) != visited.end()) {
  639. return;
  640. }
  641. visited.insert(tensor);
  642. for (int i = 0; i < GGML_MAX_SRC; i++) {
  643. add_tensor(tensor->src[i], tensors, visited);
  644. }
  645. add_tensor(tensor->view_src, tensors, visited);
  646. tensors.push_back(serialize_tensor(tensor));
  647. }
  648. static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) {
  649. uint32_t n_nodes = cgraph->n_nodes;
  650. std::vector<rpc_tensor> tensors;
  651. std::unordered_set<ggml_tensor*> visited;
  652. for (uint32_t i = 0; i < n_nodes; i++) {
  653. add_tensor(cgraph->nodes[i], tensors, visited);
  654. }
  655. // serialization format:
  656. // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
  657. uint32_t n_tensors = tensors.size();
  658. int output_size = sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(rpc_tensor);
  659. output.resize(output_size, 0);
  660. memcpy(output.data(), &n_nodes, sizeof(n_nodes));
  661. for (uint32_t i = 0; i < n_nodes; i++) {
  662. memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t));
  663. }
  664. uint32_t * out_ntensors = (uint32_t *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t));
  665. *out_ntensors = n_tensors;
  666. rpc_tensor * out_tensors = (rpc_tensor *)(output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t));
  667. memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
  668. }
  669. static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
  670. ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
  671. std::vector<uint8_t> input;
  672. serialize_graph(cgraph, input);
  673. rpc_msg_graph_compute_rsp response;
  674. auto sock = get_socket(rpc_ctx->endpoint);
  675. bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
  676. GGML_ASSERT(status);
  677. return (enum ggml_status)response.result;
  678. }
  679. static ggml_backend_i ggml_backend_rpc_interface = {
  680. /* .get_name = */ ggml_backend_rpc_name,
  681. /* .free = */ ggml_backend_rpc_free,
  682. /* .set_tensor_async = */ NULL,
  683. /* .get_tensor_async = */ NULL,
  684. /* .cpy_tensor_async = */ NULL,
  685. /* .synchronize = */ ggml_backend_rpc_synchronize,
  686. /* .graph_plan_create = */ NULL,
  687. /* .graph_plan_free = */ NULL,
  688. /* .graph_plan_update = */ NULL,
  689. /* .graph_plan_compute = */ NULL,
  690. /* .graph_compute = */ ggml_backend_rpc_graph_compute,
  691. /* .event_record = */ NULL,
  692. /* .event_wait = */ NULL,
  693. };
  694. ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
  695. static std::mutex mutex;
  696. std::lock_guard<std::mutex> lock(mutex);
  697. // NOTE: buffer types are allocated and never freed; this is by design
  698. static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
  699. auto it = buft_map.find(endpoint);
  700. if (it != buft_map.end()) {
  701. return it->second;
  702. }
  703. auto sock = get_socket(endpoint);
  704. if (sock == nullptr) {
  705. fprintf(stderr, "Failed to connect to %s\n", endpoint);
  706. return nullptr;
  707. }
  708. size_t alignment = get_alignment(sock);
  709. size_t max_size = get_max_size(sock);
  710. ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
  711. /* .endpoint = */ endpoint,
  712. /* .name = */ "RPC[" + std::string(endpoint) + "]",
  713. /* .alignment = */ alignment,
  714. /* .max_size = */ max_size
  715. };
  716. ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
  717. /* .iface = */ ggml_backend_rpc_buffer_type_interface,
  718. /* .device = */ ggml_backend_rpc_add_device(endpoint),
  719. /* .context = */ buft_ctx
  720. };
  721. buft_map[endpoint] = buft;
  722. return buft;
  723. }
  724. ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
  725. ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
  726. /* .endpoint = */ endpoint,
  727. /* .name = */ "RPC[" + std::string(endpoint) + "]",
  728. };
  729. ggml_backend_t backend = new ggml_backend {
  730. /* .guid = */ ggml_backend_rpc_guid(),
  731. /* .interface = */ ggml_backend_rpc_interface,
  732. /* .device = */ ggml_backend_rpc_add_device(endpoint),
  733. /* .context = */ ctx
  734. };
  735. return backend;
  736. }
  737. bool ggml_backend_is_rpc(ggml_backend_t backend) {
  738. return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_rpc_guid());
  739. }
  740. static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
  741. rpc_msg_get_device_memory_rsp response;
  742. bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
  743. GGML_ASSERT(status);
  744. *free = response.free_mem;
  745. *total = response.total_mem;
  746. }
  747. void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
  748. auto sock = get_socket(endpoint);
  749. if (sock == nullptr) {
  750. *free = 0;
  751. *total = 0;
  752. return;
  753. }
  754. get_device_memory(sock, free, total);
  755. }
  756. // RPC server-side implementation
  757. class rpc_server {
  758. public:
  759. rpc_server(ggml_backend_t backend, const char * cache_dir)
  760. : backend(backend), cache_dir(cache_dir) {
  761. }
  762. ~rpc_server();
  763. void hello(rpc_msg_hello_rsp & response);
  764. void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
  765. void get_alignment(rpc_msg_get_alignment_rsp & response);
  766. void get_max_size(rpc_msg_get_max_size_rsp & response);
  767. bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
  768. bool free_buffer(const rpc_msg_free_buffer_req & request);
  769. bool buffer_clear(const rpc_msg_buffer_clear_req & request);
  770. bool set_tensor(const std::vector<uint8_t> & input);
  771. bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
  772. bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
  773. bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
  774. bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
  775. bool init_tensor(const rpc_msg_init_tensor_req & request);
  776. bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
  777. private:
  778. bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
  779. ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
  780. ggml_tensor * create_node(uint64_t id,
  781. struct ggml_context * ctx,
  782. const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
  783. std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map);
  784. ggml_backend_t backend;
  785. const char * cache_dir;
  786. std::unordered_set<ggml_backend_buffer_t> buffers;
  787. };
  788. void rpc_server::hello(rpc_msg_hello_rsp & response) {
  789. response.major = RPC_PROTO_MAJOR_VERSION;
  790. response.minor = RPC_PROTO_MINOR_VERSION;
  791. response.patch = RPC_PROTO_PATCH_VERSION;
  792. GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
  793. }
  794. bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
  795. ggml_backend_buffer_type_t buft;
  796. struct ggml_init_params params {
  797. /*.mem_size =*/ ggml_tensor_overhead(),
  798. /*.mem_buffer =*/ NULL,
  799. /*.no_alloc =*/ true,
  800. };
  801. ggml_context_ptr ctx_ptr { ggml_init(params) };
  802. GGML_ASSERT(ctx_ptr != nullptr);
  803. ggml_context * ctx = ctx_ptr.get();
  804. ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
  805. if (tensor == nullptr) {
  806. GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
  807. return false;
  808. }
  809. if (tensor->buffer == nullptr) {
  810. //No buffer allocated.
  811. buft = ggml_backend_get_default_buffer_type(backend);
  812. } else {
  813. buft = tensor->buffer->buft;
  814. }
  815. response.alloc_size = ggml_backend_buft_get_alloc_size(buft,tensor);
  816. return true;
  817. }
  818. void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
  819. ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
  820. ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
  821. response.remote_ptr = 0;
  822. response.remote_size = 0;
  823. if (buffer != nullptr) {
  824. response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
  825. response.remote_size = buffer->size;
  826. GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
  827. buffers.insert(buffer);
  828. } else {
  829. GGML_LOG_ERROR("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
  830. }
  831. }
  832. void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
  833. ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
  834. size_t alignment = ggml_backend_buft_get_alignment(buft);
  835. GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
  836. response.alignment = alignment;
  837. }
  838. void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
  839. ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
  840. size_t max_size = ggml_backend_buft_get_max_size(buft);
  841. GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
  842. response.max_size = max_size;
  843. }
  844. bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
  845. GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
  846. ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
  847. if (buffers.find(buffer) == buffers.end()) {
  848. GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
  849. return false;
  850. }
  851. void * base = ggml_backend_buffer_get_base(buffer);
  852. response.base_ptr = reinterpret_cast<uint64_t>(base);
  853. return true;
  854. }
  855. bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
  856. GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
  857. ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
  858. if (buffers.find(buffer) == buffers.end()) {
  859. GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
  860. return false;
  861. }
  862. ggml_backend_buffer_free(buffer);
  863. buffers.erase(buffer);
  864. return true;
  865. }
  866. bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
  867. GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
  868. ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
  869. if (buffers.find(buffer) == buffers.end()) {
  870. GGML_LOG_ERROR("[%s] buffer not found\n", __func__);
  871. return false;
  872. }
  873. ggml_backend_buffer_clear(buffer, request.value);
  874. return true;
  875. }
  876. ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor) {
  877. // Validate tensor type before using it
  878. if (tensor->type >= GGML_TYPE_COUNT) {
  879. GGML_LOG_ERROR("[%s] invalid tensor type received: %u\n", __func__, tensor->type);
  880. return nullptr;
  881. }
  882. ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type,
  883. tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
  884. // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type
  885. if (result == nullptr) {
  886. GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type);
  887. return nullptr;
  888. }
  889. for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) {
  890. result->nb[i] = tensor->nb[i];
  891. }
  892. result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer);
  893. if (result->buffer && buffers.find(result->buffer) == buffers.end()) {
  894. result->buffer = nullptr;
  895. }
  896. if (result->buffer) {
  897. // require that the tensor data does not go beyond the buffer end
  898. uint64_t tensor_size = (uint64_t) ggml_nbytes(result);
  899. uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer);
  900. uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer);
  901. GGML_ASSERT(tensor->data + tensor_size >= tensor->data); // check for overflow
  902. GGML_ASSERT(tensor->data >= buffer_start && tensor->data + tensor_size <= buffer_start + buffer_size);
  903. }
  904. result->op = (ggml_op) tensor->op;
  905. for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) {
  906. result->op_params[i] = tensor->op_params[i];
  907. }
  908. result->flags = tensor->flags;
  909. result->data = reinterpret_cast<void *>(tensor->data);
  910. ggml_set_name(result, tensor->name);
  911. return result;
  912. }
  913. bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
  914. // serialization format: | rpc_tensor | offset (8 bytes) | data (size bytes) |
  915. if (input.size() < sizeof(rpc_tensor) + sizeof(uint64_t)) {
  916. return false;
  917. }
  918. const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
  919. uint64_t offset;
  920. memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
  921. const size_t size = input.size() - sizeof(rpc_tensor) - sizeof(offset);
  922. struct ggml_init_params params {
  923. /*.mem_size =*/ ggml_tensor_overhead(),
  924. /*.mem_buffer =*/ NULL,
  925. /*.no_alloc =*/ true,
  926. };
  927. ggml_context_ptr ctx_ptr { ggml_init(params) };
  928. GGML_ASSERT(ctx_ptr != nullptr);
  929. ggml_context * ctx = ctx_ptr.get();
  930. ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
  931. if (tensor == nullptr) {
  932. GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
  933. return false;
  934. }
  935. GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
  936. // sanitize tensor->data
  937. {
  938. const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
  939. const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
  940. if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
  941. GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu) out of buffer bounds [0x%zx, 0x%zx)\n",
  942. __func__, in_tensor->data, offset, size, p0, p1);
  943. return false;
  944. }
  945. }
  946. const void * data = input.data() + sizeof(rpc_tensor) + sizeof(offset);
  947. if (cache_dir && size > HASH_THRESHOLD) {
  948. uint64_t hash = fnv_hash((const uint8_t*)data, size);
  949. char hash_str[17];
  950. snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
  951. // save to cache_dir/hash_str
  952. fs::path cache_file = fs::path(cache_dir) / hash_str;
  953. std::ofstream ofs(cache_file, std::ios::binary);
  954. ofs.write((const char *)data, size);
  955. printf("[%s] saved to '%s'\n", __func__, cache_file.c_str());
  956. }
  957. ggml_backend_tensor_set(tensor, data, offset, size);
  958. return true;
  959. }
  960. bool rpc_server::get_cached_file(uint64_t hash, std::vector<uint8_t> & data) {
  961. if (!cache_dir) {
  962. return false;
  963. }
  964. char hash_str[17];
  965. snprintf(hash_str, sizeof(hash_str), "%016" PRIx64, hash);
  966. fs::path cache_file = fs::path(cache_dir) / hash_str;
  967. if (!fs::exists(cache_file)) {
  968. return false;
  969. }
  970. std::ifstream ifs(cache_file, std::ios::binary);
  971. ifs.seekg(0, std::ios::end);
  972. size_t size = ifs.tellg();
  973. ifs.seekg(0, std::ios::beg);
  974. data.resize(size);
  975. ifs.read((char *)data.data(), size);
  976. return true;
  977. }
  978. bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response)
  979. {
  980. std::vector<uint8_t> cached_file;
  981. if (!get_cached_file(request.hash, cached_file)) {
  982. response.result = 0;
  983. return true;
  984. }
  985. size_t size = cached_file.size();
  986. struct ggml_init_params params {
  987. /*.mem_size =*/ ggml_tensor_overhead(),
  988. /*.mem_buffer =*/ NULL,
  989. /*.no_alloc =*/ true,
  990. };
  991. ggml_context_ptr ctx_ptr { ggml_init(params) };
  992. GGML_ASSERT(ctx_ptr != nullptr);
  993. ggml_context * ctx = ctx_ptr.get();
  994. ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
  995. if (tensor == nullptr) {
  996. GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
  997. return false;
  998. }
  999. GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %zu, hash: %" PRIx64 "\n",
  1000. __func__, (void*)tensor->buffer, tensor->data, request.offset, size, request.hash);
  1001. // sanitize tensor->data
  1002. {
  1003. const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
  1004. const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
  1005. if (request.tensor.data + request.offset < p0
  1006. || request.tensor.data + request.offset >= p1
  1007. || size > (p1 - request.tensor.data - request.offset)) {
  1008. GGML_LOG_ERROR("[%s] tensor data region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%zu, hash=0x%" PRIx64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
  1009. __func__, request.tensor.data, request.offset, size, request.hash, p0, p1);
  1010. return false;
  1011. }
  1012. }
  1013. ggml_backend_tensor_set(tensor, cached_file.data(), request.offset, size);
  1014. response.result = 1;
  1015. return true;
  1016. }
  1017. bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) {
  1018. struct ggml_init_params params {
  1019. /*.mem_size =*/ ggml_tensor_overhead(),
  1020. /*.mem_buffer =*/ NULL,
  1021. /*.no_alloc =*/ true,
  1022. };
  1023. ggml_context_ptr ctx_ptr { ggml_init(params) };
  1024. GGML_ASSERT(ctx_ptr != nullptr);
  1025. ggml_context * ctx = ctx_ptr.get();
  1026. ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
  1027. if (tensor == nullptr) {
  1028. GGML_LOG_ERROR("Null tensor pointer passed to server init_tensor function.\n");
  1029. return false;
  1030. }
  1031. // Call the backend's buffer_init_tensor function
  1032. ggml_backend_buffer_t buffer = tensor->buffer;
  1033. if (buffer && buffer->iface.init_tensor) {
  1034. buffer->iface.init_tensor(buffer, tensor);
  1035. } else {
  1036. GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n");
  1037. }
  1038. if (tensor->extra != nullptr) {
  1039. // This pointer can either be passed around client/server, or probably better stored server-side and kept track of.
  1040. // Currently unimplemented.
  1041. GGML_LOG_ERROR("tensor->extra populated by the backend, this is currently unsupported.\n");
  1042. return false;
  1043. }
  1044. return true;
  1045. }
  1046. bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
  1047. struct ggml_init_params params {
  1048. /*.mem_size =*/ ggml_tensor_overhead(),
  1049. /*.mem_buffer =*/ NULL,
  1050. /*.no_alloc =*/ true,
  1051. };
  1052. ggml_context_ptr ctx_ptr { ggml_init(params) };
  1053. GGML_ASSERT(ctx_ptr != nullptr);
  1054. ggml_context * ctx = ctx_ptr.get();
  1055. ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
  1056. if (tensor == nullptr) {
  1057. GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
  1058. return false;
  1059. }
  1060. GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
  1061. // sanitize tensor->data
  1062. {
  1063. const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
  1064. const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
  1065. if (request.tensor.data + request.offset < p0 ||
  1066. request.tensor.data + request.offset >= p1 ||
  1067. request.size > (p1 - request.tensor.data - request.offset)) {
  1068. GGML_LOG_ERROR("[%s] requested tensor region (data=0x%" PRIx64 ", offset=%" PRIu64 ", size=%" PRIu64 ") out of buffer bounds [0x%zx, 0x%zx)\n",
  1069. __func__, request.tensor.data, request.offset, request.size, p0, p1);
  1070. return false;
  1071. }
  1072. }
  1073. response.resize(request.size, 0);
  1074. ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
  1075. return true;
  1076. }
  1077. bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
  1078. struct ggml_init_params params {
  1079. /*.mem_size =*/ 2*ggml_tensor_overhead(),
  1080. /*.mem_buffer =*/ NULL,
  1081. /*.no_alloc =*/ true,
  1082. };
  1083. ggml_context_ptr ctx_ptr { ggml_init(params) };
  1084. GGML_ASSERT(ctx_ptr != nullptr);
  1085. ggml_context * ctx = ctx_ptr.get();
  1086. ggml_tensor * src = deserialize_tensor(ctx, &request.src);
  1087. ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
  1088. if (src == nullptr || dst == nullptr) {
  1089. GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
  1090. return false;
  1091. }
  1092. uint64_t src_size = (uint64_t) ggml_nbytes(src);
  1093. uint64_t dst_data = (uint64_t) dst->data;
  1094. uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
  1095. uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
  1096. if (dst_data + src_size > dst_base + dst_buf_sz) {
  1097. GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
  1098. " write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
  1099. " buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
  1100. __func__,
  1101. dst_data,
  1102. dst_data + src_size,
  1103. dst_base,
  1104. dst_base + dst_buf_sz);
  1105. return false;
  1106. }
  1107. GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
  1108. __func__, (void*) src->buffer, (void*) dst->buffer);
  1109. response.result = ggml_backend_buffer_copy_tensor(src, dst);
  1110. return true;
  1111. }
  1112. ggml_tensor * rpc_server::create_node(uint64_t id,
  1113. struct ggml_context * ctx,
  1114. const std::unordered_map<uint64_t, const rpc_tensor*> & tensor_ptrs,
  1115. std::unordered_map<uint64_t, struct ggml_tensor*> & tensor_map) {
  1116. if (tensor_map.find(id) != tensor_map.end()) {
  1117. return tensor_map[id];
  1118. }
  1119. // Safely find the tensor pointer
  1120. auto it_ptr = tensor_ptrs.find(id);
  1121. if (it_ptr == tensor_ptrs.end()) {
  1122. return nullptr;
  1123. }
  1124. const rpc_tensor * tensor = it_ptr->second;
  1125. struct ggml_tensor * result = deserialize_tensor(ctx, tensor);
  1126. if (result == nullptr) {
  1127. return nullptr;
  1128. }
  1129. tensor_map[id] = result;
  1130. for (int i = 0; i < GGML_MAX_SRC; i++) {
  1131. // Check if the source ID is 0 before calling create_node recursively
  1132. if (tensor->src[i] == 0) {
  1133. result->src[i] = nullptr;
  1134. } else {
  1135. result->src[i] = create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map);
  1136. // If the recursive call failed for a non-zero ID, propagate the error
  1137. if (result->src[i] == nullptr) {
  1138. GGML_LOG_ERROR("[%s] failed to create source node %d (src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
  1139. __func__, i, tensor->src[i], id);
  1140. // Must return nullptr to signal failure up the call stack
  1141. return nullptr;
  1142. }
  1143. }
  1144. }
  1145. // Handle view_src similarly
  1146. if (tensor->view_src == 0) {
  1147. result->view_src = nullptr;
  1148. } else {
  1149. result->view_src = create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map);
  1150. // If the recursive call failed for a non-zero ID, propagate the error
  1151. if (result->view_src == nullptr) {
  1152. GGML_LOG_ERROR("[%s] failed to create view_src node (view_src_id=%" PRIu64 ") for node id %" PRIu64 "\n",
  1153. __func__, tensor->view_src, id);
  1154. // Must return nullptr to signal failure up the call stack
  1155. return nullptr;
  1156. }
  1157. }
  1158. result->view_offs = tensor->view_offs;
  1159. return result;
  1160. }
  1161. bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
  1162. // serialization format:
  1163. // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
    if (input.size() < sizeof(uint32_t)) {
        return false;
    }
    uint32_t n_nodes;
    memcpy(&n_nodes, input.data(), sizeof(n_nodes));
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t)) {
        return false;
    }
    const uint64_t * nodes = (const uint64_t *)(input.data() + sizeof(n_nodes));
    uint32_t n_tensors;
    memcpy(&n_tensors, input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t), sizeof(n_tensors));
    if (input.size() < sizeof(uint32_t) + n_nodes*sizeof(uint64_t) + sizeof(uint32_t) + n_tensors*sizeof(rpc_tensor)) {
        return false;
    }
    const rpc_tensor * tensors = (const rpc_tensor *)(input.data() + sizeof(n_nodes) + n_nodes*sizeof(uint64_t) + sizeof(n_tensors));
    GGML_PRINT_DEBUG("[%s] n_nodes: %u, n_tensors: %u\n", __func__, n_nodes, n_tensors);

    size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false);

    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };

    ggml_context_ptr ctx_ptr { ggml_init(params) };
    GGML_ASSERT(ctx_ptr != nullptr);
    ggml_context * ctx = ctx_ptr.get();
    struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
    graph->n_nodes = n_nodes;
    std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
    for (uint32_t i = 0; i < n_tensors; i++) {
        tensor_ptrs[tensors[i].id] = &tensors[i];
    }
    std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
    for (uint32_t i = 0; i < n_nodes; i++) {
        int64_t id;
        memcpy(&id, &nodes[i], sizeof(id));
        graph->nodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
        // Check if create_node failed for a *non-zero* ID.
        // If id was 0, create_node returning nullptr is expected.
        // If id was non-zero and create_node returned nullptr, it indicates a deserialization error.
        if (graph->nodes[i] == nullptr && id != 0) {
            GGML_LOG_ERROR("[%s] failed to create graph node %d (id=%" PRId64 ")\n", __func__, i, id);
            return false;
        }
    }
    ggml_status status = ggml_backend_graph_compute(backend, graph);
    response.result = status;
    return true;
}
rpc_server::~rpc_server() {
    for (auto buffer : buffers) {
        ggml_backend_buffer_free(buffer);
    }
}
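// Serve a single client connection: the client first sends a 1-byte command id,
// followed by a request message; most commands are answered with a fixed-size
// response message. A minimal sketch of the client side of the initial handshake
// (the helper names are illustrative, not defined in this section):
//
//   uint8_t cmd = RPC_CMD_HELLO;
//   send_data(sockfd, &cmd, 1);                 // 1-byte command id
//   send_msg(sockfd, nullptr, 0);               // empty HELLO request body
//   rpc_msg_hello_rsp hello;
//   recv_msg(sockfd, &hello, sizeof(hello));    // server's HELLO response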
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
                             sockfd_t sockfd, size_t free_mem, size_t total_mem) {
    rpc_server server(backend, cache_dir);
    uint8_t cmd;
    if (!recv_data(sockfd, &cmd, 1)) {
        return;
    }
    // the first command sent by the client must be HELLO
    if (cmd != RPC_CMD_HELLO) {
        fprintf(stderr, "Expected HELLO command, update client\n");
        return;
    }
    if (!recv_msg(sockfd, nullptr, 0)) {
        return;
    }
    rpc_msg_hello_rsp response;
    server.hello(response);
    if (!send_msg(sockfd, &response, sizeof(response))) {
        return;
    }
    while (true) {
        if (!recv_data(sockfd, &cmd, 1)) {
            break;
        }
        if (cmd >= RPC_CMD_COUNT) {
            // fail fast if the command is invalid
            fprintf(stderr, "Unknown command: %d\n", cmd);
            break;
        }
        switch (cmd) {
            case RPC_CMD_HELLO: {
                // HELLO command is handled above
                return;
            }
            case RPC_CMD_ALLOC_BUFFER: {
                rpc_msg_alloc_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_alloc_buffer_rsp response;
                server.alloc_buffer(request, response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALLOC_SIZE: {
                rpc_msg_get_alloc_size_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_get_alloc_size_rsp response;
                if (!server.get_alloc_size(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_ALIGNMENT: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_alignment_rsp response;
                server.get_alignment(response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_MAX_SIZE: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_max_size_rsp response;
                server.get_max_size(response);
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_GET_BASE: {
                rpc_msg_buffer_get_base_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_buffer_get_base_rsp response;
                if (!server.buffer_get_base(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_FREE_BUFFER: {
                rpc_msg_free_buffer_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.free_buffer(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_BUFFER_CLEAR: {
                rpc_msg_buffer_clear_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.buffer_clear(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                if (!server.set_tensor(input)) {
                    return;
                }
                break;
            }
            case RPC_CMD_SET_TENSOR_HASH: {
                rpc_msg_set_tensor_hash_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_set_tensor_hash_rsp response;
                if (!server.set_tensor_hash(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_INIT_TENSOR: {
                rpc_msg_init_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                if (!server.init_tensor(request)) {
                    return;
                }
                if (!send_msg(sockfd, nullptr, 0)) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_TENSOR: {
                rpc_msg_get_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                std::vector<uint8_t> response;
                if (!server.get_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, response.data(), response.size())) {
                    return;
                }
                break;
            }
            case RPC_CMD_COPY_TENSOR: {
                rpc_msg_copy_tensor_req request;
                if (!recv_msg(sockfd, &request, sizeof(request))) {
                    return;
                }
                rpc_msg_copy_tensor_rsp response;
                if (!server.copy_tensor(request, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GRAPH_COMPUTE: {
                std::vector<uint8_t> input;
                if (!recv_msg(sockfd, input)) {
                    return;
                }
                rpc_msg_graph_compute_rsp response;
                if (!server.graph_compute(input, response)) {
                    return;
                }
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            case RPC_CMD_GET_DEVICE_MEMORY: {
                if (!recv_msg(sockfd, nullptr, 0)) {
                    return;
                }
                rpc_msg_get_device_memory_rsp response;
                response.free_mem  = free_mem;
                response.total_mem = total_mem;
                if (!send_msg(sockfd, &response, sizeof(response))) {
                    return;
                }
                break;
            }
            default: {
                fprintf(stderr, "Unknown command: %d\n", cmd);
                return;
            }
        }
    }
}
void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint,
                                   const char * cache_dir,
                                   size_t free_mem, size_t total_mem) {
    printf("Starting RPC server v%d.%d.%d\n",
        RPC_PROTO_MAJOR_VERSION,
        RPC_PROTO_MINOR_VERSION,
        RPC_PROTO_PATCH_VERSION);
    printf("  endpoint       : %s\n", endpoint);
    printf("  local cache    : %s\n", cache_dir ? cache_dir : "n/a");
    printf("  backend memory : %zu MB\n", free_mem / (1024 * 1024));

    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
        return;
    }
#ifdef _WIN32
    {
        WSADATA wsaData;
        int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
        if (res != 0) {
            fprintf(stderr, "WSAStartup failed: %d\n", res);
            return;
        }
    }
#endif
    auto server_socket = create_server_socket(host.c_str(), port);
    if (server_socket == nullptr) {
        fprintf(stderr, "Failed to create server socket\n");
        return;
    }
    while (true) {
        auto client_socket = socket_accept(server_socket->fd);
        if (client_socket == nullptr) {
            fprintf(stderr, "Failed to accept client connection\n");
            return;
        }
        printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
        fflush(stdout);
        rpc_serve_client(backend, cache_dir, client_socket->fd, free_mem, total_mem);
        printf("Client connection closed\n");
        fflush(stdout);
    }
#ifdef _WIN32
    WSACleanup();
#endif
}
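// Illustrative invocation (a sketch; the backend choice and the advertised memory
// sizes are assumptions, not part of this file):
//
//   ggml_backend_t backend = ggml_backend_cpu_init();
//   size_t mem = 16ull * 1024 * 1024 * 1024; // advertise 16 GB free/total
//   ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", /*cache_dir =*/ nullptr, mem, mem);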
// device interface

struct ggml_backend_rpc_device_context {
    std::string endpoint;
    std::string name;
};

static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
    return ctx->name.c_str();
}

static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
    return ctx->name.c_str();
}

static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);

    GGML_UNUSED(dev);
}

static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
    // TODO: obtain value from the server
    return GGML_BACKEND_DEVICE_TYPE_GPU;

    GGML_UNUSED(dev);
}

static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_rpc_device_get_name(dev);
    props->description = ggml_backend_rpc_device_get_description(dev);
    props->type        = ggml_backend_rpc_device_get_type(dev);
    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ false,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ false,
    };
}

static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
    return ggml_backend_rpc_init(ctx->endpoint.c_str());

    GGML_UNUSED(params);
}

static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());

    GGML_UNUSED(dev);
}

static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    GGML_UNUSED(dev);
    GGML_UNUSED(op);
    //TODO: call the remote backend and cache the results
    return true;
}

static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
        return false;
    }
    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
    return buft_ctx->endpoint == dev_ctx->endpoint;
}
static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
    /* .get_name             = */ ggml_backend_rpc_device_get_name,
    /* .get_description      = */ ggml_backend_rpc_device_get_description,
    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
    /* .get_type             = */ ggml_backend_rpc_device_get_type,
    /* .get_props            = */ ggml_backend_rpc_device_get_props,
    /* .init_backend         = */ ggml_backend_rpc_device_init,
    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ NULL,
    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};
// backend reg interface

static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
    return "RPC";

    GGML_UNUSED(reg);
}

static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
    return 0;

    GGML_UNUSED(reg);
}

static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");

    GGML_UNUSED(reg);
    GGML_UNUSED(index);
}

static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
        return (void *)ggml_backend_rpc_add_device;
    }
    if (std::strcmp(name, "ggml_backend_rpc_start_server") == 0) {
        return (void *)ggml_backend_rpc_start_server;
    }
    return NULL;

    GGML_UNUSED(reg);
}
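// Illustrative lookup through the proc-address hook (a sketch; obtaining the
// registration handle via ggml_backend_rpc_reg() directly is only one way to get it):
//
//   typedef ggml_backend_dev_t (*rpc_add_device_fn)(const char * endpoint);
//   ggml_backend_reg_t reg = ggml_backend_rpc_reg();
//   auto add_device = (rpc_add_device_fn) ggml_backend_reg_get_proc_address(reg, "ggml_backend_rpc_add_device");
//   if (add_device != nullptr) {
//       ggml_backend_dev_t dev = add_device("localhost:50052");
//   }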
static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
};

ggml_backend_reg_t ggml_backend_rpc_reg(void) {
    static struct ggml_backend_reg ggml_backend_rpc_reg = {
        /* .api_version = */ GGML_BACKEND_API_VERSION,
        /* .iface       = */ ggml_backend_rpc_reg_i,
        /* .context     = */ NULL,
    };

    return &ggml_backend_rpc_reg;
}
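// Illustrative client-side usage of the device registration below (a sketch; the
// endpoint string and the use of the generic ggml_backend_dev_init() initializer
// are assumptions here):
//
//   ggml_backend_dev_t dev = ggml_backend_rpc_add_device("192.168.1.42:50052");
//   ggml_backend_t backend = ggml_backend_dev_init(dev, /*params =*/ nullptr);
//   // ... use `backend` like any other ggml backend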
ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;

    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    if (dev_map.find(endpoint) != dev_map.end()) {
        return dev_map[endpoint];
    }

    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
        /* .endpoint = */ endpoint,
        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
    };

    ggml_backend_dev_t dev = new ggml_backend_device {
        /* .iface   = */ ggml_backend_rpc_device_i,
        /* .reg     = */ ggml_backend_rpc_reg(),
        /* .context = */ ctx,
    };

    dev_map[endpoint] = dev;

    return dev;
}

GGML_BACKEND_DL_IMPL(ggml_backend_rpc_reg)