
rpc : fix alloc size logic (#17116)

* rpc : fix alloc size logic

* rpc : bump version
Georgi Gerganov, 1 month ago
commit 8160b38a5f
2 changed files with 39 additions and 10 deletions
  1. ggml/include/ggml-rpc.h (+1 -2)
  2. ggml/src/ggml-rpc/ggml-rpc.cpp (+38 -8)

ggml/include/ggml-rpc.h (+1 -2)

@@ -1,6 +1,5 @@
 #pragma once
 
-#include "ggml.h"
 #include "ggml-backend.h"
 
 #ifdef  __cplusplus
@@ -8,7 +7,7 @@ extern "C" {
 #endif
 
 #define RPC_PROTO_MAJOR_VERSION    3
-#define RPC_PROTO_MINOR_VERSION    5
+#define RPC_PROTO_MINOR_VERSION    6
 #define RPC_PROTO_PATCH_VERSION    0
 #define GGML_RPC_MAX_SERVERS       16
 

ggml/src/ggml-rpc/ggml-rpc.cpp (+38 -8)

@@ -128,6 +128,7 @@ struct rpc_msg_device_count_rsp {
 struct rpc_msg_get_alloc_size_req {
     uint32_t   device;
     rpc_tensor tensor;
+    rpc_tensor srcs[GGML_MAX_SRC];
 };
 
 struct rpc_msg_get_alloc_size_rsp {
@@ -572,6 +573,11 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
+    if (!tensor) {
+        memset(&result, 0, sizeof(result));
+        return result;
+    }
+
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
     if (tensor->buffer) {
@@ -753,23 +759,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    // should we query the remote server for the actual size
+    bool rpc_get = false;
+
     // See comments in init_tensor.
-    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
+
+    // ops that require additional memory for fleeting data on certain backends
+    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
+    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
+
+    if (rpc_get) {
         ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
         auto sock = get_socket(buft_ctx->endpoint);
 
-        rpc_msg_get_alloc_size_req request;
-        request.device = buft_ctx->device;
-        request.tensor = serialize_tensor(tensor);
+        rpc_msg_get_alloc_size_req request = {
+            /*.device =*/ buft_ctx->device,
+            /*.tensor =*/ serialize_tensor(tensor),
+            /*.srcs   =*/ {},
+        };
+
+        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            request.srcs[i] = serialize_tensor(tensor->src[i]);
+        }
 
+        // TODO: cache the alloc responses to avoid extra RPC calls?
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
         RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
-    } else {
-        return ggml_nbytes(tensor);
     }
+
+    return ggml_nbytes(tensor);
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@@ -1017,7 +1041,7 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     }
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
@@ -1025,12 +1049,18 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
 
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (request.srcs[i].id != 0) {
+            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
+        }
+    }
+
     LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         //No buffer allocated.
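
One more detail worth calling out: the server now deserializes up to GGML_MAX_SRC additional tensors into the same scratch context, which is why mem_size grows from ggml_tensor_overhead() to ggml_tensor_overhead()*(1 + GGML_MAX_SRC). A no_alloc context stores only tensor metadata, one overhead-sized slot per tensor created in it. A small self-contained sketch of that sizing rule (the helper name is made up for illustration):

#include "ggml.h"

// Size a no_alloc ggml context for one tensor plus all of its potential sources.
// no_alloc == true means no data buffers are reserved, only per-tensor metadata,
// so the context needs ggml_tensor_overhead() bytes per tensor it will hold.
static struct ggml_context * make_alloc_size_scratch_ctx(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * (1 + GGML_MAX_SRC),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    return ggml_init(params); // caller frees with ggml_free (the diff wraps it in ggml_context_ptr)
}

This over-reserves when a node has fewer than GGML_MAX_SRC sources, but the scratch context is tiny and freed as soon as the handler returns.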