@@ -25,7 +25,7 @@
# include <netdb.h>
# include <unistd.h>
#endif
-#include <string.h>
+#include <cstring>

#define UNUSED GGML_UNUSED
@@ -630,22 +630,6 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
    return (enum ggml_status)output[0];
}

-static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    UNUSED(backend);
-    UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
-}
-
-static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->endpoint == rpc_ctx->endpoint;
-}
-
static ggml_backend_i ggml_backend_rpc_interface = {
    /* .get_name                = */ ggml_backend_rpc_name,
    /* .free                    = */ ggml_backend_rpc_free,
@@ -659,8 +643,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
    /* .graph_plan_update       = */ NULL,
    /* .graph_plan_compute      = */ NULL,
    /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .supports_op             = */ ggml_backend_rpc_supports_op,
-    /* .supports_buft           = */ ggml_backend_rpc_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
    /* .offload_op              = */ NULL,
    /* .event_record            = */ NULL,
    /* .event_wait              = */ NULL,
@@ -691,7 +675,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en

    ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
        /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ nullptr,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
        /* .context = */ buft_ctx
    };
    buft_map[endpoint] = buft;
@@ -707,7 +691,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
    ggml_backend_t backend = new ggml_backend {
        /* .guid      = */ ggml_backend_rpc_guid(),
        /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
        /* .context   = */ ctx
    };
    return backend;
@@ -1189,7 +1173,7 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
    }
}

-void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
    std::string host;
    int port;
    if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1210,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
    WSACleanup();
#endif
}
+
+// device interface
+
+struct ggml_backend_rpc_device_context {
+    std::string endpoint;
+    std::string name;
+};
+
+static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
+
+    UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+    // TODO: obtain value from the server
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    UNUSED(dev);
+}
+
+static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_rpc_device_get_name(dev);
+    props->description = ggml_backend_rpc_device_get_description(dev);
+    props->type        = ggml_backend_rpc_device_get_type(dev);
+    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+
+    UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+
+    UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    UNUSED(dev);
+    UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    UNUSED(dev);
+    UNUSED(op);
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+    return buft_ctx->endpoint == dev_ctx->endpoint;
+}
+
+static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+    /* .get_name             = */ ggml_backend_rpc_device_get_name,
+    /* .get_description      = */ ggml_backend_rpc_device_get_description,
+    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
+    /* .get_type             = */ ggml_backend_rpc_device_get_type,
+    /* .get_props            = */ ggml_backend_rpc_device_get_props,
+    /* .init_backend         = */ ggml_backend_rpc_device_init,
+    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    return "RPC";
+
+    UNUSED(reg);
+}
+
+static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 0;
+
+    UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
+
+    UNUSED(reg);
+    UNUSED(index);
+}
+
+static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
+        return (void *)ggml_backend_rpc_add_device;
+    }
+    return NULL;
+
+    UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+    static struct ggml_backend_reg ggml_backend_rpc_reg = {
+        /* .iface   = */ ggml_backend_rpc_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_rpc_reg;
+}
+
+ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (dev_map.find(endpoint) != dev_map.end()) {
+        return dev_map[endpoint];
+    }
+
+    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
+        /* .endpoint = */ endpoint,
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+    };
+
+    ggml_backend_dev_t dev = new ggml_backend_device {
+        /* .iface   = */ ggml_backend_rpc_device_i,
+        /* .reg     = */ ggml_backend_rpc_reg(),
+        /* .context = */ ctx,
+    };
+
+    dev_map[endpoint] = dev;
+
+    return dev;
+}
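
Reviewer note: a minimal client-side sketch of how the entry points touched by this patch fit together. Each endpoint is registered once as a device through ggml_backend_rpc_add_device (cached per endpoint, as in the function above), and ggml_backend_rpc_init / ggml_backend_rpc_buffer_type now attach that device instead of nullptr. The header name, the endpoint string, and the ggml_backend_free cleanup call are assumptions for illustration, not part of the patch.

// Hypothetical usage sketch; header and endpoint are assumptions.
#include "ggml-rpc.h"

int main() {
    const char * endpoint = "127.0.0.1:50052";   // assumed address of a running RPC server

    // Register (or fetch the cached) device for this endpoint.
    ggml_backend_dev_t dev = ggml_backend_rpc_add_device(endpoint);

    // Backend and buffer type for the same endpoint now reference that device.
    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    ggml_backend_buffer_type_t buft = ggml_backend_rpc_buffer_type(endpoint);

    (void) dev; (void) buft;                     // graph building / scheduling omitted

    ggml_backend_free(backend);                  // assumed standard ggml cleanup
    return 0;
}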