@@ -425,6 +425,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
     vk_command_pool compute_cmd_pool;
     vk_command_pool transfer_cmd_pool;
+
+    // number of additional consecutive nodes that are being fused with the
+    // node currently being processed
+    uint32_t num_additional_fused_ops {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
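Both variants are compiled from the same `rms_norm_f32` SPIR-V: they now bind three buffers (src0, src1, dst) and use the binary push-constant block, and differ only in the trailing specialization constants (`{0, 0}` vs `{0, 1}`), where the second constant presumably toggles the fused multiply. A scalar sketch of the per-row result the two variants are expected to produce (illustrative only, not the shader code; the epsilon placement follows ggml's RMS_NORM definition):

```cpp
#include <cmath>

// Illustrative only: what pipeline_rms_norm_f32 (do_multiply == false) and
// pipeline_rms_norm_mul_f32 (do_multiply == true) are expected to compute for
// one row of n elements. Buffer names and the flag are assumptions.
static void rms_norm_row_sketch(const float * x, const float * b, float * d,
                                int n, float eps, bool do_multiply) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (double)x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt((float)(sum / n) + eps);
    for (int i = 0; i < n; ++i) {
        float y = scale * x[i];      // plain GGML_OP_RMS_NORM output
        if (do_multiply) {
            y *= b[i];               // fused GGML_OP_MUL with the second operand
        }
        d[i] = y;
    }
}
```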
@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_rms_norm_f32;
+            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
         0,
-        op_params[0], 0.0f,
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        op_params[0], 0.0f, 0,
     }, dryrun);
 }
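Switching the dispatch from `vk_op_unary_push_constants` to `vk_op_binary_push_constants` is what makes room for the second operand: the initializer now carries dimensions and element strides for src0, src1 and dst, followed by an offset word and the op parameters (`op_params[0]` is the RMS_NORM epsilon). As a reading aid, the values line up roughly like the sketch below; the field names are a paraphrase, the authoritative definition is the `vk_op_binary_push_constants` struct in ggml-vulkan.cpp.

```cpp
#include <cstdint>

// Rough layout implied by the initializer above (field names are assumptions):
struct binary_push_constants_sketch {
    uint32_t ne;                                              // ggml_nelements(src0)
    uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 shape / strides (in elements)
    uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // src1 shape / strides
    uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;  // dst  shape / strides
    uint32_t misalign_offsets;                                // the literal 0
    float    param1;                                          // op_params[0] == eps
    float    param2;                                          // 0.0f, unused here
    int32_t  param3;                                          // 0, unused here
};
```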
@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+    ggml_tensor * node = cgraph->nodes[node_idx];
     if (ggml_is_empty(node) || !node->buffer) {
         return false;
     }
@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+        if (ctx->num_additional_fused_ops > 0) {
+            // fused rms_norm + mul
+            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+        } else {
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+        }
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
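The fused branch leans on the shape of the graph: the node after the RMS_NORM is the MUL that consumes it, so the dispatch reads the RMS_NORM input plus the MUL's other operand and writes the MUL node's tensor directly; the intermediate normalization result is never materialized. In the non-fused branch, src0 is passed as src1 as well, simply so the binary pipeline has a valid second binding, and with the multiply disabled that binding is presumably never read. At graph-construction time the targeted pattern looks like this (a sketch with made-up tensor names, not code from this patch):

```cpp
#include "ggml.h"

// Illustrative only: the two-node pattern the fused path targets.
static ggml_tensor * rms_norm_mul_pattern(ggml_context * ctx0, ggml_tensor * inp,
                                          ggml_tensor * weight, float eps) {
    ggml_tensor * cur = ggml_rms_norm(ctx0, inp, eps);  // node:     GGML_OP_RMS_NORM
    return ggml_mul(ctx0, cur, weight);                 // node + 1: GGML_OP_MUL (consumes cur)
}
```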
@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
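In this dry-run pass the fusion decision is made per node: when `ggml_can_fuse` reports that node `i` (RMS_NORM) and node `i + 1` (MUL) can be combined, `num_additional_fused_ops` is set to 1, the single fused node is built, and the loop then skips the MUL so it is not recorded a second time; the counter is reset afterwards so unrelated nodes are unaffected. The sketch below paraphrases the kind of precondition such a check has to establish; it is an assumption about `ggml_can_fuse`, not its implementation, and the real helper may apply further conditions (for example on intermediate use counts).

```cpp
#include "ggml-impl.h"   // struct ggml_cgraph

// Paraphrase (assumption) of the precondition for fusing RMS_NORM + MUL:
// the MUL must be the immediately following node, it must consume the
// RMS_NORM result, and that intermediate must not be needed elsewhere.
static bool can_fuse_rms_norm_mul_sketch(const struct ggml_cgraph * g, int i) {
    if (i + 1 >= g->n_nodes) {
        return false;
    }
    const ggml_tensor * norm = g->nodes[i];
    const ggml_tensor * mul  = g->nodes[i + 1];
    if (norm->op != GGML_OP_RMS_NORM || mul->op != GGML_OP_MUL) {
        return false;
    }
    const bool consumed  = (mul->src[0] == norm) || (mul->src[1] == norm);
    const bool is_output = (norm->flags & GGML_TENSOR_FLAG_OUTPUT) != 0;
    return consumed && !is_output;
}
```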
@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i == last_node) ||
+                      (i + ctx->num_additional_fused_ops == last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {
@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             } else {
                 compute_ctx = ctx->compute_ctx.lock();
             }
-            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+            // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+            for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+            }
         }
 
         if (enqueued) {
@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             }
             submit_count++;
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
 
     if (vk_perf_logger_enabled) {