Sfoglia il codice sorgente

rpc : check for null buffers in get/set/copy tensor endpoints (#14868)

Chris Rohlf 5 mesi fa
parent
commit
64bf1c3744
1 ha cambiato i file con 4 aggiunte e 4 eliminazioni
  1. 4 4
      ggml/src/ggml-rpc/ggml-rpc.cpp

+ 4 - 4
ggml/src/ggml-rpc/ggml-rpc.cpp

@@ -1055,7 +1055,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1124,7 +1124,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1192,7 +1192,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1229,7 +1229,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
 
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
-    if (src == nullptr || dst == nullptr) {
+    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
         return false;
     }