|
|
@@ -92,12 +92,19 @@ enum rpc_cmd {
|
|
|
RPC_CMD_GET_DEVICE_MEMORY,
|
|
|
RPC_CMD_INIT_TENSOR,
|
|
|
RPC_CMD_GET_ALLOC_SIZE,
|
|
|
+ RPC_CMD_HELLO,
|
|
|
RPC_CMD_COUNT,
|
|
|
};
|
|
|
|
|
|
// Size threshold (10 MiB): when the data is larger than this,
// RPC_CMD_SET_TENSOR_HASH is tried first.
const size_t HASH_THRESHOLD = static_cast<size_t>(10) * 1024 * 1024;
|
|
|
|
|
|
// Response to RPC_CMD_HELLO: the protocol version reported by the server,
// one byte per component.
struct rpc_msg_hello_rsp {
    uint8_t major;  // major protocol version
    uint8_t minor;  // minor protocol version
    uint8_t patch;  // patch protocol version
};
|
|
|
+
|
|
|
struct rpc_msg_get_alloc_size_req {
|
|
|
rpc_tensor tensor;
|
|
|
};
|
|
|
@@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
|
|
|
|
// RPC client-side implementation
|
|
|
|
|
|
+static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
|
|
+ rpc_msg_hello_rsp response;
|
|
|
+ bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
|
|
+ GGML_ASSERT(status);
|
|
|
+ if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
|
|
+ fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
|
|
+ fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+}
|
|
|
+
|
|
|
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
static std::mutex mutex;
|
|
|
std::lock_guard<std::mutex> lock(mutex);
|
|
|
@@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
if (sock == nullptr) {
|
|
|
return nullptr;
|
|
|
}
|
|
|
+ if (!check_server_version(sock)) {
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
|
|
sockets[endpoint] = sock;
|
|
|
return sock;
|
|
|
@@ -818,6 +842,7 @@ public:
|
|
|
}
|
|
|
~rpc_server();
|
|
|
|
|
|
+ void hello(rpc_msg_hello_rsp & response);
|
|
|
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
|
|
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
|
|
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
|
|
@@ -846,6 +871,13 @@ private:
|
|
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
|
|
};
|
|
|
|
|
|
+void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
|
|
+ response.major = RPC_PROTO_MAJOR_VERSION;
|
|
|
+ response.minor = RPC_PROTO_MINOR_VERSION;
|
|
|
+ response.patch = RPC_PROTO_PATCH_VERSION;
|
|
|
+ GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
|
|
+}
|
|
|
+
|
|
|
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
|
|
ggml_backend_buffer_type_t buft;
|
|
|
struct ggml_init_params params {
|
|
|
@@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
|
|
|
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
|
|
rpc_server server(backend, cache_dir);
|
|
|
+ uint8_t cmd;
|
|
|
+ if (!recv_data(sockfd, &cmd, 1)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ // the first command sent by the client must be HELLO
|
|
|
+ if (cmd != RPC_CMD_HELLO) {
|
|
|
+ fprintf(stderr, "Expected HELLO command, update client\n");
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ if (!recv_msg(sockfd, nullptr, 0)) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
+ rpc_msg_hello_rsp response;
|
|
|
+ server.hello(response);
|
|
|
+ if (!send_msg(sockfd, &response, sizeof(response))) {
|
|
|
+ return;
|
|
|
+ }
|
|
|
while (true) {
|
|
|
- uint8_t cmd;
|
|
|
if (!recv_data(sockfd, &cmd, 1)) {
|
|
|
break;
|
|
|
}
|
|
|
@@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|
|
break;
|
|
|
}
|
|
|
switch (cmd) {
|
|
|
+ case RPC_CMD_HELLO: {
|
|
|
+ // HELLO command is handled above
|
|
|
+ return;
|
|
|
+ }
|
|
|
case RPC_CMD_ALLOC_BUFFER: {
|
|
|
rpc_msg_alloc_buffer_req request;
|
|
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|