@@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
                                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
                                                                              GGML_OP_RESHAPE };
+
+static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
+                                                                            GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
+                                                                            GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                                            GGML_OP_DIV, GGML_OP_RESHAPE };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                          GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
                                                                         GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                         GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
     { 9, 0, 8 }, // reshape->src[0] == div
 };
 
+//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
+//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
+//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
+//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
+//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
+//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
+//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
+//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
+//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
+//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
+//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
+    { 1, 0, 0 }, // reshape->src[0] == sigmoid
+    { 2, 0, 0 }, // add->src[0] == sigmoid
+    { 3, 0, 2 }, // argsort->src[0] == add
+    { 4, 0, 3 }, // view->src[0] == argsort
+    { 5, 0, 1 }, // get_rows->src[0] == reshape
+    { 5, 1, 4 }, // get_rows->src[1] == view
+    { 6, 0, 5 }, // reshape->src[0] == get_rows
+    { 7, 0, 6 }, // sum_rows->src[0] == reshape
+    { 8, 0, 7 }, // clamp->src[0] == sum_rows
+    { 9, 0, 6 }, // div->src[0] == reshape
+    { 9, 1, 8 }, // div->src[1] == clamp
+    {10, 0, 9 }, // reshape->src[0] == div
+};
+
 // same as early_softmax_norm but ending after the get_rows
 static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
     { 1, 0, 0 }, // reshape->src[0] == softmax
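The new edge list mirrors the routing subgraph shown in the node dump above (nodes #436 to #446): the ARGSORT that selects the experts consumes the biased probabilities, while the GET_ROWS that produces the weights gathers from the unbiased sigmoid output, which is why the probs == selection_probs requirement is relaxed for this mode later in the patch. A minimal sketch of such a subgraph written against the public ggml API; the tensor names, shapes and the clamp bound are illustrative and not taken from llama.cpp:

    // logits: [n_expert, n_tokens] router output, bias: [n_expert] per-expert selection bias
    ggml_tensor * probs   = ggml_sigmoid(ctx, logits);                            // UNARY(SIGMOID)
    ggml_tensor * probs_r = ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens);   // RESHAPE
    ggml_tensor * biased  = ggml_add(ctx, probs, bias);                           // ADD: bias only steers selection
    ggml_tensor * sorted  = ggml_argsort(ctx, biased, GGML_SORT_ORDER_DESC);      // ARGSORT on the biased scores
    ggml_tensor * topk    = ggml_view_2d(ctx, sorted, n_expert_used, n_tokens,
                                         sorted->nb[1], 0);                       // VIEW -> ids
    ggml_tensor * w       = ggml_get_rows(ctx, probs_r, topk);                    // GET_ROWS from the unbiased probs
    w = ggml_reshape_2d(ctx, w, n_expert_used, n_tokens);                         // RESHAPE
    ggml_tensor * sum     = ggml_sum_rows(ctx, w);                                // SUM_ROWS
    sum = ggml_clamp(ctx, sum, 1e-9f, INFINITY);                                  // CLAMP (guard the division)
    w = ggml_div(ctx, w, sum);                                                    // DIV (normalize)
    w = ggml_reshape_3d(ctx, w, 1, n_expert_used, n_tokens);                      // RESHAPE -> weights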
@@ -491,16 +524,10 @@ enum topk_moe_mode {
     TOPK_MOE_EARLY_SOFTMAX,
     TOPK_MOE_EARLY_SOFTMAX_NORM,
     TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_SIGMOID_NORM_BIAS,
     TOPK_MOE_COUNT,
 };
 
-static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
-    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
-                         num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
-                         TOPK_MOE_LATE_SOFTMAX;
-    return mode;
-}
-
 static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
     { 1, 0, 0 }, // view->src[0] == rope
     { 2, 0, 1 }, // set_rows->src[0] == view
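Dropping ggml_vk_num_additional_ops_to_topk_moe_mode fits the rest of the change: once an optional trailing GGML_OP_SCALE can be folded into the fusion (see the graph_compute hunks further down), the fused-op count alone no longer identifies the mode, so it is recorded explicitly in ctx->fused_topk_moe_mode instead. A quick count against the patterns above (worked arithmetic, not code from the patch):

    topk_moe_early_softmax_norm.size() - 1 ==  9   // 10 ops in the pattern
    topk_moe_sigmoid_norm_bias.size()  - 1 == 10   // 11 ops in the pattern
    // early_softmax_norm plus a fused trailing SCALE also reaches 9 + 1 == 10 additional ops,
    // which would collide with sigmoid_norm_bias if the mode were still derived from the count.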
@@ -766,7 +793,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_count_experts;
 
     // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
 
     std::vector<vk_pipeline_ref> all_pipelines;
 
@@ -1181,6 +1208,11 @@ struct vk_op_topk_moe_push_constants {
     uint32_t n_expert_used;
     float clamp_min;
    float clamp_max;
+    uint32_t gating_func;
+    uint32_t has_bias;
+    uint32_t with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 struct vk_op_add_id_push_constants {
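Taken together, the new push constants describe the whole fused computation, which is what lets a single shader variant replace the per-mode pipelines. A rough CPU reference for one token row in the new TOPK_MOE_SIGMOID_NORM_BIAS mode, written under the assumption that the shader follows the same recipe; the function name, helpers and row layout are illustrative, not shader code:

    // needs <vector>, <numeric>, <algorithm>, <cmath>, <cstdint>
    void topk_moe_sigmoid_row_ref(const vk_op_topk_moe_push_constants & pc, uint32_t n_experts,
                                  const float * logits, const float * bias,
                                  float * weights, int32_t * ids) {
        std::vector<float> gate(n_experts), sel(n_experts);
        for (uint32_t e = 0; e < n_experts; ++e) {
            gate[e] = 1.0f / (1.0f + std::exp(-logits[e]));      // gating_func == sigmoid
            sel[e]  = gate[e] + (pc.has_bias ? bias[e] : 0.0f);  // bias affects selection only
        }
        std::vector<int32_t> order(n_experts);
        std::iota(order.begin(), order.end(), 0);
        std::partial_sort(order.begin(), order.begin() + pc.n_expert_used, order.end(),
                          [&](int32_t a, int32_t b) { return sel[a] > sel[b]; });
        float sum = 0.0f;
        for (uint32_t k = 0; k < pc.n_expert_used; ++k) { sum += gate[order[k]]; }
        sum = std::min(std::max(sum, pc.clamp_min), pc.clamp_max);
        for (uint32_t k = 0; k < pc.n_expert_used; ++k) {
            ids[k] = order[k];
            // with_norm divides by the clamped sum; output_scale/output_bias come from a trailing
            // GGML_OP_SCALE when one is fused and default to 1.0f / 0.0f otherwise.
            float w = pc.with_norm ? gate[order[k]] / sum : gate[order[k]];
            weights[k] = w * pc.output_scale + pc.output_bias;
        }
    }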
@@ -1771,6 +1803,8 @@ struct ggml_backend_vk_context {
     // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
     // If there's no fusion, bit 0 is still set.
     int fused_ops_write_mask {};
+    topk_moe_mode fused_topk_moe_mode {};
+    bool fused_topk_moe_scale {};
 
     // for GGML_VK_PERF_LOGGER
     std::unique_ptr<vk_perf_logger> perf_logger;
@@ -4291,9 +4325,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (uint32_t use_push = 0; use_push < 2; ++use_push) {
         for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
         }
     }
 
@@ -8684,10 +8716,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (ctx->num_additional_fused_ops) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             GGML_ASSERT(idx < num_topk_moe_pipelines);
-            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
             // use n_experts from push constant if it's not equal to the power of two spec constant
             bool use_push = dst->ne[0] != (1u << idx);
-            return ctx->device->pipeline_topk_moe[idx][mode][use_push];
+            return ctx->device->pipeline_topk_moe[idx][use_push];
         }
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
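The pipeline key is now just the rounded-up power-of-two expert bucket plus the use_push flag; the mode no longer selects a separate pipeline. A worked illustration of the two branches (uses <cmath>; not code from the patch):

    uint32_t idx      = (uint32_t)ceilf(log2f(128.0f)); // 7
    bool     use_push = 128 != (1u << idx);             // false: the 2^7 spec-constant variant matches exactly
    idx      = (uint32_t)ceilf(log2f(96.0f));           // also 7
    use_push = 96 != (1u << idx);                       // true: the expert count comes from pc.n_experts_push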
@@ -10346,14 +10377,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
 }
 
 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
-    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+    topk_moe_mode mode = ctx->fused_topk_moe_mode;
     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
-    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
-                            (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
-                            cgraph->nodes[node_idx + 5];
-    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
+    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
+    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
+                        (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
+                        cgraph->nodes[node_idx + 3];
 
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(bias->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
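Indexing the weights as node_idx + num_additional_fused_ops works because the weights tensor is always produced by the last node of whichever pattern was fused, including an optionally fused trailing scale. Spelled out against the patterns above (an illustration of the offsets the removed ternary used to hard-code):

    // TOPK_MOE_EARLY_SOFTMAX_NORM : 10 ops -> weights at node_idx + 9  (final RESHAPE)
    // TOPK_MOE_SIGMOID_NORM_BIAS  : 11 ops -> weights at node_idx + 10 (final RESHAPE)
    // TOPK_MOE_EARLY_SOFTMAX      :  5 ops -> weights at node_idx + 4  (GET_ROWS)
    // TOPK_MOE_LATE_SOFTMAX       :  6 ops -> weights at node_idx + 5  (final RESHAPE)
    // +1 in every case when a trailing GGML_OP_SCALE is fused; weights then points at the SCALE node,
    // which is what the GGML_ASSERT(weights->op == GGML_OP_SCALE) below relies on.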
@@ -10375,18 +10409,45 @@
     pc.n_rows = n_rows;
     pc.n_experts_push = n_experts;
     pc.n_expert_used = n_expert_used;
+    pc.clamp_min = -std::numeric_limits<float>::infinity();
+    pc.clamp_max = std::numeric_limits<float>::infinity();
     if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
         ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
+        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+    }
+    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
+        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
         pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
         pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
     }
 
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
+                     mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
+                     GATING_FUNC_SOFTMAX;
+    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    if (ctx->fused_topk_moe_scale) {
+        GGML_ASSERT(weights->op == GGML_OP_SCALE);
+        pc.output_scale = ggml_get_op_params_f32(weights, 0);
+        pc.output_bias = ggml_get_op_params_f32(weights, 1);
+    } else {
+        pc.output_scale = 1.0f;
+        pc.output_bias = 0.0f;
+    }
+
     GGML_ASSERT(n_expert_used <= n_experts);
 
     const uint32_t rows_per_block = 4;
     std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
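Initializing clamp_min/clamp_max to negative/positive infinity (and output_scale/output_bias to 1.0f/0.0f when no scale is fused) lets every mode share the same code path: presumably the shader applies the clamp and the output scale unconditionally, and with these defaults both steps degenerate to the identity for modes that lack the corresponding graph nodes. A trivial illustration of why that is safe (an assumption about the shader, not taken from it; uses <algorithm>, <cassert>, <limits>):

    const float inf = std::numeric_limits<float>::infinity();
    float x = 0.25f;
    assert(std::clamp(x, -inf, inf) == x); // default clamp bounds are a no-op
    assert(x * 1.0f + 0.0f == x);          // default scale/bias are a no-op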
@@ -12128,6 +12189,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_UNARY:
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+            break;
+        }
+
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
@@ -12175,7 +12241,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12195,7 +12261,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -13048,6 +13114,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         get_rows = cgraph->nodes[node_idx + 4];
         argsort = cgraph->nodes[node_idx + 2];
         break;
+    case TOPK_MOE_SIGMOID_NORM_BIAS:
+        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
+        weights = cgraph->nodes[node_idx + 10];
+        get_rows = cgraph->nodes[node_idx + 5];
+        argsort = cgraph->nodes[node_idx + 3];
+        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        // bias is expected to be 1D
+        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+            return false;
+        }
+        // sigmoid fusion seems to generate infinities on moltenvk
+        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+            return false;
+        }
+        break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
@@ -13071,26 +13155,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         probs = probs->src[0];
     ggml_tensor * selection_probs = argsort->src[0];
 
-    if (probs != selection_probs) {
+    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
         return false;
     }
 
-    const float * op_params = (const float *)softmax->op_params;
-
-    float scale = op_params[0];
-    float max_bias = op_params[1];
-
     if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
         return false;
     }
 
-    if (scale != 1.0f || max_bias != 0.0f) {
-        return false;
-    }
+    if (softmax->op == GGML_OP_SOFT_MAX) {
+        const float * op_params = (const float *)softmax->op_params;
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
-        return false;
+        float scale = op_params[0];
+        float max_bias = op_params[1];
+
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
     }
 
     const int n_expert = softmax->ne[0];
@@ -13363,6 +13449,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+        ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
         if (!ctx->device->disable_fusion) {
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@@ -13408,13 +13496,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
+                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 4;
+                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
+                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@@ -13422,8 +13520,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
             }
+            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
+                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
+                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
+                    ctx->fused_topk_moe_scale = true;
+                    ctx->num_additional_fused_ops++;
+                }
+            }
         }
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
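The two ggml_can_fuse_subgraph calls cover the two shapes the trailing scale can take: the first matches the normalized patterns, whose fused region ends in DIV -> RESHAPE (so the check starts one node before the end), and the second matches the early-softmax pattern, which ends in GET_ROWS. A hedged sketch of how such a graph might look on the model side; the variable names and the scaling factor are illustrative, not taken from any model implementation:

    // weights as produced by the routing subgraph (see the sigmoid/bias sketch near the top) ...
    ggml_tensor * weights = ggml_reshape_3d(ctx, ggml_div(ctx, w, sum), 1, n_expert_used, n_tokens);
    // ... followed by a per-model routed scaling factor. Fusing it folds the multiply into the
    // top-k shader through pc.output_scale instead of launching a separate SCALE dispatch.
    weights = ggml_scale(ctx, weights, routed_scaling_factor);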
@@ -13602,6 +13709,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
         if (keep_pattern(topk_moe_early_softmax_norm)) {
             continue;
         }
+        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
+            continue;
+        }
         if (keep_pattern(topk_moe_early_softmax)) {
             continue;
         }
@@ -13628,6 +13738,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             // Don't pull forward nodes from fusion patterns
             if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
                 match_pattern(topk_moe_early_softmax, j) ||
                 match_pattern(topk_moe_late_softmax, j)) {
                 continue;