|
@@ -6,6 +6,7 @@
|
|
|
#include <string>
|
|
#include <string>
|
|
|
#include <vector>
|
|
#include <vector>
|
|
|
#include <memory>
|
|
#include <memory>
|
|
|
|
|
+#include <mutex>
|
|
|
#include <unordered_map>
|
|
#include <unordered_map>
|
|
|
#include <unordered_set>
|
|
#include <unordered_set>
|
|
|
#ifdef _WIN32
|
|
#ifdef _WIN32
|
|
@@ -47,6 +48,7 @@ struct socket_t {
|
|
|
sockfd_t fd;
|
|
sockfd_t fd;
|
|
|
socket_t(sockfd_t fd) : fd(fd) {}
|
|
socket_t(sockfd_t fd) : fd(fd) {}
|
|
|
~socket_t() {
|
|
~socket_t() {
|
|
|
|
|
+ GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
|
|
|
#ifdef _WIN32
|
|
#ifdef _WIN32
|
|
|
closesocket(this->fd);
|
|
closesocket(this->fd);
|
|
|
#else
|
|
#else
|
|
@@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
struct ggml_backend_rpc_buffer_type_context {
|
|
struct ggml_backend_rpc_buffer_type_context {
|
|
|
- std::shared_ptr<socket_t> sock;
|
|
|
|
|
|
|
+ std::string endpoint;
|
|
|
std::string name;
|
|
std::string name;
|
|
|
size_t alignment;
|
|
size_t alignment;
|
|
|
size_t max_size;
|
|
size_t max_size;
|
|
@@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
|
|
|
struct ggml_backend_rpc_context {
|
|
struct ggml_backend_rpc_context {
|
|
|
std::string endpoint;
|
|
std::string endpoint;
|
|
|
std::string name;
|
|
std::string name;
|
|
|
- std::shared_ptr<socket_t> sock;
|
|
|
|
|
- ggml_backend_buffer_type_t buft;
|
|
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
struct ggml_backend_rpc_buffer_context {
|
|
struct ggml_backend_rpc_buffer_context {
|
|
@@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
|
|
|
|
|
- std::string str(endpoint);
|
|
|
|
|
- size_t pos = str.find(':');
|
|
|
|
|
|
|
+static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
|
|
|
|
+ size_t pos = endpoint.find(':');
|
|
|
if (pos == std::string::npos) {
|
|
if (pos == std::string::npos) {
|
|
|
return false;
|
|
return false;
|
|
|
}
|
|
}
|
|
|
- host = str.substr(0, pos);
|
|
|
|
|
- port = std::stoi(str.substr(pos + 1));
|
|
|
|
|
|
|
+ host = endpoint.substr(0, pos);
|
|
|
|
|
+ port = std::stoi(endpoint.substr(pos + 1));
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|
|
|
|
|
|
|
// RPC client-side implementation
|
|
// RPC client-side implementation
|
|
|
|
|
|
|
|
|
|
+static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|
|
|
|
+ static std::mutex mutex;
|
|
|
|
|
+ std::lock_guard<std::mutex> lock(mutex);
|
|
|
|
|
+ static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
|
|
|
|
|
+ static bool initialized = false;
|
|
|
|
|
+
|
|
|
|
|
+ auto it = sockets.find(endpoint);
|
|
|
|
|
+ if (it != sockets.end()) {
|
|
|
|
|
+ if (auto sock = it->second.lock()) {
|
|
|
|
|
+ return sock;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ std::string host;
|
|
|
|
|
+ int port;
|
|
|
|
|
+ if (!parse_endpoint(endpoint, host, port)) {
|
|
|
|
|
+ return nullptr;
|
|
|
|
|
+ }
|
|
|
|
|
+#ifdef _WIN32
|
|
|
|
|
+ if (!initialized) {
|
|
|
|
|
+ WSADATA wsaData;
|
|
|
|
|
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
|
|
|
|
+ if (res != 0) {
|
|
|
|
|
+ return nullptr;
|
|
|
|
|
+ }
|
|
|
|
|
+ initialized = true;
|
|
|
|
|
+ }
|
|
|
|
|
+#else
|
|
|
|
|
+ UNUSED(initialized);
|
|
|
|
|
+#endif
|
|
|
|
|
+ auto sock = socket_connect(host.c_str(), port);
|
|
|
|
|
+ if (sock == nullptr) {
|
|
|
|
|
+ return nullptr;
|
|
|
|
|
+ }
|
|
|
|
|
+ GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
|
|
|
|
+ sockets[endpoint] = sock;
|
|
|
|
|
+ return sock;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
|
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
|
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
|
return ctx->name.c_str();
|
|
return ctx->name.c_str();
|
|
@@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
|
std::vector<uint8_t> input(input_size, 0);
|
|
std::vector<uint8_t> input(input_size, 0);
|
|
|
memcpy(input.data(), &size, sizeof(size));
|
|
memcpy(input.data(), &size, sizeof(size));
|
|
|
std::vector<uint8_t> output;
|
|
std::vector<uint8_t> output;
|
|
|
- bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
|
|
|
|
|
|
|
+ auto sock = get_socket(buft_ctx->endpoint);
|
|
|
|
|
+ bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
|
|
|
GGML_ASSERT(status);
|
|
GGML_ASSERT(status);
|
|
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
|
|
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
|
@@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|
|
if (remote_ptr != 0) {
|
|
if (remote_ptr != 0) {
|
|
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
|
|
ggml_backend_rpc_buffer_interface,
|
|
ggml_backend_rpc_buffer_interface,
|
|
|
- new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
|
|
|
|
|
|
|
+ new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
|
|
|
remote_size);
|
|
remote_size);
|
|
|
return buffer;
|
|
return buffer;
|
|
|
} else {
|
|
} else {
|
|
@@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
|
|
|
}
|
|
}
|
|
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
|
- return buft_ctx->sock == rpc_ctx->sock;
|
|
|
|
|
|
|
+ return buft_ctx->endpoint == rpc_ctx->endpoint;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
@@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|
|
/* .is_host = */ NULL,
|
|
/* .is_host = */ NULL,
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
-
|
|
|
|
|
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
|
|
|
|
|
@@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|
|
|
|
|
|
|
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
|
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
|
|
- ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
|
|
|
|
|
- delete buft_ctx;
|
|
|
|
|
- delete rpc_ctx->buft;
|
|
|
|
|
delete rpc_ctx;
|
|
delete rpc_ctx;
|
|
|
delete backend;
|
|
delete backend;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
|
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
|
|
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
|
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
|
|
- return ctx->buft;
|
|
|
|
|
|
|
+ return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
|
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
|
@@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|
|
std::vector<uint8_t> input;
|
|
std::vector<uint8_t> input;
|
|
|
serialize_graph(cgraph, input);
|
|
serialize_graph(cgraph, input);
|
|
|
std::vector<uint8_t> output;
|
|
std::vector<uint8_t> output;
|
|
|
- bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
|
|
|
|
|
|
|
+ auto sock = get_socket(rpc_ctx->endpoint);
|
|
|
|
|
+ bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
|
|
|
GGML_ASSERT(status);
|
|
GGML_ASSERT(status);
|
|
|
GGML_ASSERT(output.size() == 1);
|
|
GGML_ASSERT(output.size() == 1);
|
|
|
return (enum ggml_status)output[0];
|
|
return (enum ggml_status)output[0];
|
|
@@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|
|
/* .event_synchronize = */ NULL,
|
|
/* .event_synchronize = */ NULL,
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
-static std::unordered_map<std::string, ggml_backend_t> instances;
|
|
|
|
|
-
|
|
|
|
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
|
|
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
|
|
|
|
|
- return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
|
|
|
- std::string endpoint_str(endpoint);
|
|
|
|
|
- if (instances.find(endpoint_str) != instances.end()) {
|
|
|
|
|
- return instances[endpoint_str];
|
|
|
|
|
- }
|
|
|
|
|
-#ifdef _WIN32
|
|
|
|
|
- {
|
|
|
|
|
- WSADATA wsaData;
|
|
|
|
|
- int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
|
|
|
|
- if (res != 0) {
|
|
|
|
|
- return nullptr;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-#endif
|
|
|
|
|
- fprintf(stderr, "Connecting to %s\n", endpoint);
|
|
|
|
|
- std::string host;
|
|
|
|
|
- int port;
|
|
|
|
|
- if (!parse_endpoint(endpoint, host, port)) {
|
|
|
|
|
- return nullptr;
|
|
|
|
|
- }
|
|
|
|
|
- auto sock = socket_connect(host.c_str(), port);
|
|
|
|
|
|
|
+ static std::mutex mutex;
|
|
|
|
|
+ std::lock_guard<std::mutex> lock(mutex);
|
|
|
|
|
+ // NOTE: buffer types are allocated and never freed; this is by design
|
|
|
|
|
+ static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
|
|
|
|
+ auto it = buft_map.find(endpoint);
|
|
|
|
|
+ if (it != buft_map.end()) {
|
|
|
|
|
+ return it->second;
|
|
|
|
|
+ }
|
|
|
|
|
+ auto sock = get_socket(endpoint);
|
|
|
if (sock == nullptr) {
|
|
if (sock == nullptr) {
|
|
|
return nullptr;
|
|
return nullptr;
|
|
|
}
|
|
}
|
|
|
size_t alignment = get_alignment(sock);
|
|
size_t alignment = get_alignment(sock);
|
|
|
size_t max_size = get_max_size(sock);
|
|
size_t max_size = get_max_size(sock);
|
|
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
|
|
- /* .sock = */ sock,
|
|
|
|
|
- /* .name = */ "RPC" + std::to_string(sock->fd),
|
|
|
|
|
|
|
+ /* .endpoint = */ endpoint,
|
|
|
|
|
+ /* .name = */ "RPC[" + std::string(endpoint) + "]",
|
|
|
/* .alignment = */ alignment,
|
|
/* .alignment = */ alignment,
|
|
|
- /* .max_size = */ max_size
|
|
|
|
|
|
|
+ /* .max_size = */ max_size
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
|
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
|
|
/* .context = */ buft_ctx
|
|
/* .context = */ buft_ctx
|
|
|
};
|
|
};
|
|
|
|
|
+ buft_map[endpoint] = buft;
|
|
|
|
|
+ return buft;
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
|
|
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
|
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
|
|
- /* .endpoint = */ endpoint,
|
|
|
|
|
- /* .name = */ "RPC" + std::to_string(sock->fd),
|
|
|
|
|
- /* .sock = */ sock,
|
|
|
|
|
- /* .buft = */ buft
|
|
|
|
|
|
|
+ /* .endpoint = */ endpoint,
|
|
|
|
|
+ /* .name = */ "RPC",
|
|
|
};
|
|
};
|
|
|
|
|
|
|
|
- instances[endpoint] = new ggml_backend {
|
|
|
|
|
|
|
+ ggml_backend_t backend = new ggml_backend {
|
|
|
/* .guid = */ ggml_backend_rpc_guid(),
|
|
/* .guid = */ ggml_backend_rpc_guid(),
|
|
|
/* .interface = */ ggml_backend_rpc_interface,
|
|
/* .interface = */ ggml_backend_rpc_interface,
|
|
|
/* .context = */ ctx
|
|
/* .context = */ ctx
|
|
|
};
|
|
};
|
|
|
-
|
|
|
|
|
- return instances[endpoint];
|
|
|
|
|
|
|
+ return backend;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
|
@@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
|
|
- ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
|
|
|
|
|
- if (backend == nullptr) {
|
|
|
|
|
|
|
+ auto sock = get_socket(endpoint);
|
|
|
|
|
+ if (sock == nullptr) {
|
|
|
*free = 0;
|
|
*free = 0;
|
|
|
*total = 0;
|
|
*total = 0;
|
|
|
return;
|
|
return;
|
|
|
}
|
|
}
|
|
|
- ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
|
|
|
|
- get_device_memory(ctx->sock, free, total);
|
|
|
|
|
|
|
+ get_device_memory(sock, free, total);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// RPC server-side implementation
|
|
// RPC server-side implementation
|