|
@@ -385,6 +385,14 @@ enum shader_reduction_mode {
|
|
|
|
|
|
|
|
static constexpr uint32_t num_argsort_pipelines = 11;
|
|
static constexpr uint32_t num_argsort_pipelines = 11;
|
|
|
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
|
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
|
|
|
|
|
+static constexpr uint32_t num_topk_moe_pipelines = 10;
|
|
|
|
|
+
|
|
|
|
|
+static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
|
|
|
|
+ GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
|
|
|
|
|
+ GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
|
|
|
|
|
+static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
|
|
|
|
|
+ GGML_OP_VIEW, GGML_OP_GET_ROWS };
|
|
|
|
|
+
|
|
|
|
|
|
|
|
struct vk_device_struct {
|
|
struct vk_device_struct {
|
|
|
std::recursive_mutex mutex;
|
|
std::recursive_mutex mutex;
|
|
@@ -598,6 +606,9 @@ struct vk_device_struct {
|
|
|
|
|
|
|
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
|
|
|
|
|
|
|
|
+ // [2] is {!norm, norm}
|
|
|
|
|
+ vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
|
|
|
|
|
+
|
|
|
std::vector<vk_pipeline_ref> all_pipelines;
|
|
std::vector<vk_pipeline_ref> all_pipelines;
|
|
|
|
|
|
|
|
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
|
|
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
|
|
@@ -941,6 +952,11 @@ struct vk_op_multi_add_push_constants {
|
|
|
static_assert(MAX_PARAMETER_COUNT == 12);
|
|
static_assert(MAX_PARAMETER_COUNT == 12);
|
|
|
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
|
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
|
|
|
|
|
|
|
|
|
|
+struct vk_op_topk_moe_push_constants {
|
|
|
|
|
+ uint32_t n_rows;
|
|
|
|
|
+ uint32_t n_expert_used;
|
|
|
|
|
+};
|
|
|
|
|
+
|
|
|
struct vk_op_add_id_push_constants {
|
|
struct vk_op_add_id_push_constants {
|
|
|
uint32_t ne0;
|
|
uint32_t ne0;
|
|
|
uint32_t ne1;
|
|
uint32_t ne1;
|
|
@@ -3722,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
|
|
|
|
|
|
|
|
+ for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
|
|
|
|
|
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
|
|
|
|
|
+ ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
for (auto &c : compiles) {
|
|
for (auto &c : compiles) {
|
|
|
c.wait();
|
|
c.wait();
|
|
|
}
|
|
}
|
|
@@ -8004,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
|
|
|
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
|
|
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
|
|
|
|
|
|
|
|
|
|
+ if (ctx->num_additional_fused_ops) {
|
|
|
|
|
+ uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
|
|
|
|
|
+ GGML_ASSERT(idx < num_topk_moe_pipelines);
|
|
|
|
|
+ bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
|
|
|
|
+ return ctx->device->pipeline_topk_moe[idx][with_norm];
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
|
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
|
|
|
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
|
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
|
|
|
}
|
|
}
|
|
@@ -9589,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
|
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
|
|
|
|
|
+
|
|
|
|
|
+ bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
|
|
|
|
|
+ ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
|
|
|
|
|
+ ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
|
|
|
|
+ ggml_tensor * ids = cgraph->nodes[node_idx + 3];
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(logits->type == GGML_TYPE_F32);
|
|
|
|
|
+ GGML_ASSERT(weights->type == GGML_TYPE_F32);
|
|
|
|
|
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
|
|
|
|
|
+
|
|
|
|
|
+ const int n_experts = logits->ne[0];
|
|
|
|
|
+ const int n_rows = logits->ne[1];
|
|
|
|
|
+ const int n_expert_used = weights->ne[1];
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
|
|
|
|
|
+
|
|
|
|
|
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
|
|
|
|
|
+
|
|
|
|
|
+ if (dryrun) {
|
|
|
|
|
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
|
|
|
|
|
+ ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
|
|
|
|
|
+ ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
|
|
|
|
|
+
|
|
|
|
|
+ vk_buffer d_logits = nullptr;
|
|
|
|
|
+ size_t logits_buf_offset = 0;
|
|
|
|
|
+ vk_buffer d_weights = nullptr;
|
|
|
|
|
+ size_t weights_buf_offset = 0;
|
|
|
|
|
+ vk_buffer d_ids = nullptr;
|
|
|
|
|
+ size_t ids_buf_offset = 0;
|
|
|
|
|
+
|
|
|
|
|
+ bool logits_uma = false;
|
|
|
|
|
+ bool weights_uma = false;
|
|
|
|
|
+ bool ids_uma = false;
|
|
|
|
|
+
|
|
|
|
|
+ if (ctx->device->uma) {
|
|
|
|
|
+ ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
|
|
|
|
|
+ ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
|
|
|
|
|
+ ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
|
|
|
|
|
+ logits_uma = d_logits != nullptr;
|
|
|
|
|
+ weights_uma = d_weights != nullptr;
|
|
|
|
|
+ ids_uma = d_ids != nullptr;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (!logits_uma) {
|
|
|
|
|
+ d_logits = logits_buf_ctx->dev_buffer;
|
|
|
|
|
+ logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
|
|
|
|
|
+ GGML_ASSERT(d_logits != nullptr);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (!weights_uma) {
|
|
|
|
|
+ d_weights = weights_buf_ctx->dev_buffer;
|
|
|
|
|
+ weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
|
|
|
|
|
+ GGML_ASSERT(d_weights != nullptr);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (!ids_uma) {
|
|
|
|
|
+ d_ids = ids_buf_ctx->dev_buffer;
|
|
|
|
|
+ ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
|
|
|
|
|
+ GGML_ASSERT(d_ids != nullptr);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ vk_op_topk_moe_push_constants pc;
|
|
|
|
|
+ pc.n_rows = n_rows;
|
|
|
|
|
+ pc.n_expert_used = n_expert_used;
|
|
|
|
|
+
|
|
|
|
|
+ GGML_ASSERT(n_expert_used <= n_experts);
|
|
|
|
|
+
|
|
|
|
|
+ const uint32_t rows_per_block = 4;
|
|
|
|
|
+ std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
|
|
|
|
|
+
|
|
|
|
|
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
|
|
|
+ {
|
|
|
|
|
+ ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
|
|
|
|
|
+ ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
|
|
|
|
|
+ ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
|
|
|
|
|
+ }, pc, elements);
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
|
|
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
|
|
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
@@ -11174,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
ctx->unsynced_nodes_read.clear();
|
|
ctx->unsynced_nodes_read.clear();
|
|
|
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
ggml_vk_sync_buffers(ctx, compute_ctx);
|
|
|
}
|
|
}
|
|
|
- // Add the last fused node and all fused source nodes to the unsynchronized list.
|
|
|
|
|
- const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
|
|
|
|
|
- ctx->unsynced_nodes_written.push_back(last_node);
|
|
|
|
|
|
|
+ // Add all fused nodes to the unsynchronized lists.
|
|
|
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
|
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
|
|
|
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
|
|
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
|
|
|
|
|
+ // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
|
|
|
|
|
+ ctx->unsynced_nodes_written.push_back(cur_node);
|
|
|
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
|
|
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
|
|
|
if (!cur_node->src[j]) {
|
|
if (!cur_node->src[j]) {
|
|
|
continue;
|
|
continue;
|
|
@@ -11345,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
|
|
|
|
|
|
|
break;
|
|
break;
|
|
|
case GGML_OP_SOFT_MAX:
|
|
case GGML_OP_SOFT_MAX:
|
|
|
- ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
|
|
|
|
|
+ if (ctx->num_additional_fused_ops) {
|
|
|
|
|
+ ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
|
|
|
|
|
+ } else {
|
|
|
|
|
+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
break;
|
|
break;
|
|
|
case GGML_OP_SOFT_MAX_BACK:
|
|
case GGML_OP_SOFT_MAX_BACK:
|
|
@@ -12141,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
|
|
|
return true;
|
|
return true;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
|
|
|
|
|
+ int node_idx, bool with_norm) {
|
|
|
|
|
+
|
|
|
|
|
+ if (with_norm) {
|
|
|
|
|
+ if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
|
|
|
|
|
+ if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ for (size_t i = 0; i < topk_moe.size(); ++i) {
|
|
|
|
|
+ if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
|
|
|
|
|
+ const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
|
|
|
|
|
+
|
|
|
|
|
+ const float * op_params = (const float *)softmax->op_params;
|
|
|
|
|
+
|
|
|
|
|
+ float scale = op_params[0];
|
|
|
|
|
+ float max_bias = op_params[1];
|
|
|
|
|
+
|
|
|
|
|
+ if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (scale != 1.0f || max_bias != 0.0f) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // don't fuse when masks or sinks are present
|
|
|
|
|
+ if (softmax->src[1] || softmax->src[2]) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const int n_expert = softmax->ne[0];
|
|
|
|
|
+ // n_expert must be a power of 2
|
|
|
|
|
+ if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // Check that the nodes don't have any unexpected uses
|
|
|
|
|
+ const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
|
|
|
|
|
+ const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
|
|
|
|
|
+ const ggml_tensor * view = cgraph->nodes[node_idx + 3];
|
|
|
|
|
+ const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
|
|
|
|
|
+ const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
|
|
|
|
|
+ const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
|
|
|
|
|
+ const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
|
|
|
|
|
+ const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
|
|
|
|
|
+
|
|
|
|
|
+ // softmax is used by reshape and argsort
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
|
|
|
|
|
+ reshape1->src[0] != softmax ||
|
|
|
|
|
+ argsort->src[0] != softmax) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ // reshape is used by get_rows
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
|
|
|
|
|
+ get_rows->src[0] != reshape1) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ // argsort is used by view
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
|
|
|
|
|
+ view->src[0] != argsort) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ // view is written (via argsort), we can skip checking it
|
|
|
|
|
+
|
|
|
|
|
+ if (with_norm) {
|
|
|
|
|
+ // get_rows is used by reshape
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
|
|
|
|
|
+ reshape5->src[0] != get_rows) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // reshape is used by sum_rows and div
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
|
|
|
|
|
+ sum_rows->src[0] != reshape5 ||
|
|
|
|
|
+ div->src[0] != reshape5) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // sum_rows is used by div
|
|
|
|
|
+ if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
|
|
|
|
|
+ div->src[1] != sum_rows) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ // div/reshape are written
|
|
|
|
|
+ if (reshape8->src[0] != div) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if (!ctx->device->subgroup_arithmetic ||
|
|
|
|
|
+ !ctx->device->subgroup_shuffle ||
|
|
|
|
|
+ !ctx->device->subgroup_require_full_support ||
|
|
|
|
|
+ ctx->device->disable_fusion) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return true;
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
|
|
|
|
|
|
|
|
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
|
const ggml_tensor *first_node = cgraph->nodes[node_idx];
|
|
@@ -12216,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
ctx->num_additional_fused_ops = num_adds - 1;
|
|
ctx->num_additional_fused_ops = num_adds - 1;
|
|
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
|
ctx->num_additional_fused_ops = 1;
|
|
ctx->num_additional_fused_ops = 1;
|
|
|
|
|
+ } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
|
|
|
|
+ ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
|
|
|
|
+ } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
|
|
|
|
+ ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
|
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
|
|
@@ -12313,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
ctx->num_additional_fused_ops = num_adds - 1;
|
|
ctx->num_additional_fused_ops = num_adds - 1;
|
|
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
|
|
|
ctx->num_additional_fused_ops = 1;
|
|
ctx->num_additional_fused_ops = 1;
|
|
|
|
|
+ } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
|
|
|
|
|
+ ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
|
|
|
|
|
+ } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
|
|
|
|
|
+ ctx->num_additional_fused_ops = topk_moe.size() - 1;
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -12320,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
|
- (i + ctx->num_additional_fused_ops == last_node) ||
|
|
|
|
|
|
|
+ (i + ctx->num_additional_fused_ops >= last_node) ||
|
|
|
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
|
|
|
|
|
|
- bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
|
|
|
|
|
|
|
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
|
|
|
|
|
|
|
|
if (vk_perf_logger_enabled) {
|
|
if (vk_perf_logger_enabled) {
|
|
|
if (ctx->compute_ctx.expired()) {
|
|
if (ctx->compute_ctx.expired()) {
|
|
@@ -12444,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
|
|
|
while (first_unused < graph->n_nodes) {
|
|
while (first_unused < graph->n_nodes) {
|
|
|
std::vector<int> current_set;
|
|
std::vector<int> current_set;
|
|
|
|
|
|
|
|
|
|
+ // Avoid reordering topk_moe_norm
|
|
|
|
|
+ if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
|
|
|
|
|
+ bool is_topk_moe_norm = true;
|
|
|
|
|
+ for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
|
|
|
|
+ if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
|
|
|
|
|
+ is_topk_moe_norm = false;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ if (is_topk_moe_norm) {
|
|
|
|
|
+ for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
|
|
|
|
|
+ new_order.push_back(graph->nodes[first_unused + j]);
|
|
|
|
|
+ used[first_unused + j] = true;
|
|
|
|
|
+ }
|
|
|
|
|
+ while (first_unused < graph->n_nodes && used[first_unused]) {
|
|
|
|
|
+ first_unused++;
|
|
|
|
|
+ }
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
// First, grab the next unused node.
|
|
// First, grab the next unused node.
|
|
|
current_set.push_back(first_unused);
|
|
current_set.push_back(first_unused);
|
|
|
|
|
|