|
|
@@ -181,7 +181,7 @@ struct ggml_backend_rpc_context {
|
|
|
|
|
|
struct ggml_backend_rpc_buffer_context {
|
|
|
std::shared_ptr<socket_t> sock;
|
|
|
- std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
|
|
|
+ void * base_ptr;
|
|
|
uint64_t remote_ptr;
|
|
|
};
|
|
|
|
|
|
@@ -423,16 +423,15 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
|
|
|
|
|
static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
|
|
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
|
|
- if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
|
|
|
- return ctx->base_cache[buffer];
|
|
|
+ if (ctx->base_ptr != nullptr) {
|
|
|
+ return ctx->base_ptr;
|
|
|
}
|
|
|
rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
|
|
|
rpc_msg_buffer_get_base_rsp response;
|
|
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
|
|
|
GGML_ASSERT(status);
|
|
|
- void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
|
|
- ctx->base_cache[buffer] = base_ptr;
|
|
|
- return base_ptr;
|
|
|
+ ctx->base_ptr = reinterpret_cast<void *>(response.base_ptr);
|
|
|
+ return ctx->base_ptr;
|
|
|
}
|
|
|
|
|
|
static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
|
|
|
@@ -557,7 +556,7 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
|
|
|
if (response.remote_ptr != 0) {
|
|
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
|
|
ggml_backend_rpc_buffer_interface,
|
|
|
- new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr},
|
|
|
+ new ggml_backend_rpc_buffer_context{sock, nullptr, response.remote_ptr},
|
|
|
response.remote_size);
|
|
|
return buffer;
|
|
|
} else {
|