Browse Source

rpc : add check for rpc buffer type (#18242)

Chris Rohlf 1 month ago
parent
commit
12ee1763a6
1 changed files with 5 additions and 5 deletions
  1. 5 5
      ggml/src/ggml-rpc/ggml-rpc.cpp

+ 5 - 5
ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -571,6 +571,10 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->base_ptr;
 }
 
+static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
+    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
+}
+
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
     if (!tensor) {
@@ -580,7 +584,7 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
 
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
-    if (tensor->buffer) {
+    if (tensor->buffer && ggml_backend_buffer_is_rpc(tensor->buffer)) {
         ggml_backend_buffer_t buffer = tensor->buffer;
         ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
         result.buffer = ctx != nullptr ? ctx->remote_ptr : 0;
@@ -664,10 +668,6 @@ static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, con
     RPC_STATUS_ASSERT(status);
 }
 
-static bool ggml_backend_buffer_is_rpc(ggml_backend_buffer_t buffer) {
-    return buffer->iface.free_buffer == ggml_backend_rpc_buffer_free_buffer;
-}
-
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
     if (ggml_backend_buffer_is_rpc(src->buffer)) {
         // check if src and dst are on the same server