@@ -128,6 +128,7 @@ struct rpc_msg_device_count_rsp {
 struct rpc_msg_get_alloc_size_req {
     uint32_t device;
     rpc_tensor tensor;
+    rpc_tensor srcs[GGML_MAX_SRC];
 };
 
 struct rpc_msg_get_alloc_size_rsp {
@@ -572,6 +573,11 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
     rpc_tensor result;
+    if (!tensor) {
+        memset(&result, 0, sizeof(result));
+        return result;
+    }
+
     result.id = reinterpret_cast<uint64_t>(tensor);
     result.type = tensor->type;
     if (tensor->buffer) {
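The null check added to serialize_tensor is what lets optional sources travel over the wire: a missing src serializes to an all-zero rpc_tensor, and the server treats an id of 0 as "no source". A minimal sketch of that invariant, using only names that appear in this patch:

    // sketch: a nullptr src and a zeroed rpc_tensor are interchangeable on the wire
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        rpc_tensor s = serialize_tensor(tensor->src[i]); // all-zero when src[i] == nullptr
        GGML_ASSERT((tensor->src[i] == nullptr) == (s.id == 0));
    }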
@@ -753,23 +759,41 @@ static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    // should we query the remote server for the actual size
+    bool rpc_get = false;
+
     // See comments in init_tensor.
-    if (ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr)) {
+    rpc_get |= ggml_is_quantized(tensor->type) && (tensor->ne[0] % 512 != 0) && (tensor->view_src == nullptr);
+
+    // ops that require additional memory for fleeting data on certain backends
+    // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+    rpc_get |= tensor->op == GGML_OP_FLASH_ATTN_EXT;
+    rpc_get |= tensor->op == GGML_OP_MUL_MAT_ID;
+
+    if (rpc_get) {
         ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
         auto sock = get_socket(buft_ctx->endpoint);
 
-        rpc_msg_get_alloc_size_req request;
-        request.device = buft_ctx->device;
-        request.tensor = serialize_tensor(tensor);
+        rpc_msg_get_alloc_size_req request = {
+            /*.device =*/ buft_ctx->device,
+            /*.tensor =*/ serialize_tensor(tensor),
+            /*.srcs   =*/ {},
+        };
+
+        // .get_alloc_size could be a function of the tensor's srcs, so we must serialize them as well
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            request.srcs[i] = serialize_tensor(tensor->src[i]);
+        }
 
+        // TODO: cache the alloc responses to avoid extra RPC calls?
         rpc_msg_get_alloc_size_rsp response;
         bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALLOC_SIZE, &request, sizeof(request), &response, sizeof(response));
         RPC_STATUS_ASSERT(status);
 
         return response.alloc_size;
-    } else {
-        return ggml_nbytes(tensor);
     }
+
+    return ggml_nbytes(tensor);
 }
 
 static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
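The srcs array exists because a backend's get_alloc_size hook may size the allocation from the destination tensor's sources, not just from the destination itself (the "fleeting data" case referenced above). A hypothetical hook sketching that dependency; the name and the sizing rule are illustrative, not an actual backend:

    // hypothetical: scratch space derived from a source's shape, which the
    // remote server could not reproduce if the request omitted the srcs
    static size_t example_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
        size_t size = ggml_nbytes(tensor);
        if (tensor->op == GGML_OP_FLASH_ATTN_EXT && tensor->src[0] != nullptr) {
            size += (size_t) tensor->src[0]->ne[1] * sizeof(float); // illustrative per-row workspace
        }
        GGML_UNUSED(buft);
        return size;
    }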
@@ -1017,7 +1041,7 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     }
     ggml_backend_buffer_type_t buft;
     struct ggml_init_params params {
-        /*.mem_size   =*/ ggml_tensor_overhead(),
+        /*.mem_size   =*/ ggml_tensor_overhead()*(1 + GGML_MAX_SRC),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
@@ -1025,12 +1049,18 @@ bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_
     ggml_context_ptr ctx_ptr { ggml_init(params) };
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
-    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
 
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_LOG_ERROR("Null tensor pointer passed to server get_alloc_size function.\n");
         return false;
     }
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        if (request.srcs[i].id != 0) {
+            tensor->src[i] = deserialize_tensor(ctx, &request.srcs[i]);
+        }
+    }
+
     LOG_DBG("[%s] device: %d, buffer: %p, data: %p\n", __func__, dev_id, (void*)tensor->buffer, tensor->data);
     if (tensor->buffer == nullptr) {
         //No buffer allocated.