Jelajahi Sumber

rpc : cache and reuse compute graphs (#15405)

Store the last computed graph and reuse it when possible.
Also do not return response from GRAPH_COMPUTE and assume it always
completes successfully. If this this is not the case, the server closes
the connection. This saves us a network round trip to the server.
Radoslav Gerganov 1 bulan lalu
induk
melakukan
15d2b46b4d
2 mengubah file dengan 92 tambahan dan 21 penghapusan
  1. 1 1
      ggml/include/ggml-rpc.h
  2. 91 20
      ggml/src/ggml-rpc/ggml-rpc.cpp

+ 1 - 1
ggml/include/ggml-rpc.h

@@ -8,7 +8,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION    3
-#define RPC_PROTO_MINOR_VERSION    0
+#define RPC_PROTO_MINOR_VERSION    5
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 

+ 91 - 20
ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -106,6 +106,7 @@ enum rpc_cmd {
     RPC_CMD_GET_ALLOC_SIZE,
     RPC_CMD_HELLO,
     RPC_CMD_DEVICE_COUNT,
+    RPC_CMD_GRAPH_RECOMPUTE,
     RPC_CMD_COUNT,
 };
 
@@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
     uint8_t result;
 };
 
-struct rpc_msg_graph_compute_rsp {
-    uint8_t result;
-};
-
 struct rpc_msg_get_device_memory_req {
     uint32_t device;
 };
@@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
     uint64_t free_mem;
     uint64_t total_mem;
 };
+
+struct rpc_msg_graph_recompute_req {
+    uint32_t device;
+};
+
 #pragma pack(pop)
 
 // RPC data structures
@@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
     size_t      max_size;
 };
 
+struct graph_cache {
+
+    bool is_cached(const ggml_cgraph * cgraph) {
+        if ((int)last_graph.size() != cgraph->n_nodes) {
+            return false;
+        }
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    void add(const ggml_cgraph * cgraph) {
+        last_graph.resize(cgraph->n_nodes);
+        for (int i = 0; i < cgraph->n_nodes; i++) {
+            memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
+        }
+    }
+
+    std::vector<ggml_tensor> last_graph;
+};
+
 struct ggml_backend_rpc_context {
     std::string endpoint;
     uint32_t    device;
     std::string name;
+    graph_cache gc;
 };
 
 struct ggml_backend_rpc_buffer_context {
@@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
 
 static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    std::vector<uint8_t> input;
-    serialize_graph(rpc_ctx->device, cgraph, input);
-    rpc_msg_graph_compute_rsp response;
-    auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
-    RPC_STATUS_ASSERT(status);
-    return (enum ggml_status)response.result;
+
+    GGML_ASSERT(cgraph->n_nodes > 0);
+    bool reuse = rpc_ctx->gc.is_cached(cgraph);
+    if (reuse) {
+        rpc_msg_graph_recompute_req request;
+        request.device = rpc_ctx->device;
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
+        RPC_STATUS_ASSERT(status);
+    } else {
+        rpc_ctx->gc.add(cgraph);
+        std::vector<uint8_t> input;
+        serialize_graph(rpc_ctx->device, cgraph, input);
+        auto sock = get_socket(rpc_ctx->endpoint);
+        bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
+        RPC_STATUS_ASSERT(status);
+    }
+    return GGML_STATUS_SUCCESS;
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
         /* .endpoint = */ endpoint,
         /* .device   = */ device,
-        /* .name     = */ dev_name
+        /* .name     = */ dev_name,
+        /* .gc       = */ {},
     };
     auto reg = ggml_backend_rpc_add_server(endpoint);
     ggml_backend_t backend = new ggml_backend {
@@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
 
 class rpc_server {
 public:
-    rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
-        : backends(std::move(backends)), cache_dir(cache_dir) {
+    rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
+        : backends(std::move(all_backends)), cache_dir(cache_dir) {
+        stored_graphs.resize(backends.size());
     }
     ~rpc_server();
 
@@ -936,11 +976,17 @@ public:
     bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
     bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
     bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
-    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input);
+    bool graph_recompute(const rpc_msg_graph_recompute_req & request);
     bool init_tensor(const rpc_msg_init_tensor_req & request);
     bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
     bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
 
+    struct stored_graph {
+        ggml_context_ptr ctx_ptr;
+        ggml_cgraph *    graph;
+    };
+
 private:
     bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -953,6 +999,8 @@ private:
     std::vector<ggml_backend_t> backends;
     const char * cache_dir;
     std::unordered_set<ggml_backend_buffer_t> buffers;
+    // store the last computed graph for each backend
+    std::vector<stored_graph> stored_graphs;
 };
 
 void rpc_server::hello(rpc_msg_hello_rsp & response) {
@@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
     // serialization format:
     // | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < 2*sizeof(uint32_t)) {
@@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
         }
     }
     ggml_status status = ggml_backend_graph_compute(backends[device], graph);
-    response.result = status;
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
+    stored_graphs[device].ctx_ptr.swap(ctx_ptr);
+    stored_graphs[device].graph = graph;
+    return true;
+}
+
+bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
+    uint32_t device = request.device;
+    if (device >= backends.size()) {
+        return false;
+    }
+    if (stored_graphs[device].graph == nullptr) {
+        return false;
+    }
+    ggml_cgraph * graph = stored_graphs[device].graph;
+    LOG_DBG("[%s] device: %u\n", __func__, device);
+    ggml_status status = ggml_backend_graph_compute(backends[device], graph);
+    GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
     return true;
 }
 
@@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
                 if (!recv_msg(sockfd, input)) {
                     return;
                 }
-                rpc_msg_graph_compute_rsp response;
-                if (!server.graph_compute(input, response)) {
+                if (!server.graph_compute(input)) {
                     return;
                 }
-                if (!send_msg(sockfd, &response, sizeof(response))) {
+                break;
+            }
+            case RPC_CMD_GRAPH_RECOMPUTE: {
+                rpc_msg_graph_recompute_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.graph_recompute(request)) {
                     return;
                 }
                 break;