Parcourir la source

vulkan: Fuse mul_mat_id+add_id+mul and mul_mat+add+add. (#17287)

These both show up in gpt-oss. Also, cleanup the mul_mat_vec fusion code a bit.
Jeff Bolz il y a 2 mois
Parent
commit
24dc769f1b

+ 179 - 67
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -32,6 +32,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
 #include <memory>
 #include <limits>
 #include <map>
+#include <set>
 #include <unordered_map>
 #include <memory>
 #include <mutex>
@@ -824,6 +825,12 @@ struct vk_mat_mat_push_constants {
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
     uint32_t padded_N;
 };
+
+#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
+#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
+#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
+#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
+
 struct vk_mat_vec_push_constants {
     uint32_t ncols;
     uint32_t stride_a;
@@ -832,8 +839,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_a;
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
-    uint32_t enable_bias;
-    uint32_t enable_scale;
+    uint32_t fusion_flags;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -847,7 +853,7 @@ struct vk_mat_vec_p021_push_constants {
     uint32_t nchannels_y;
     uint32_t b_offset;
     uint32_t d_offset;
-    uint32_t enable_bias;
+    uint32_t fusion_flags;
 };
 
 struct vk_mat_vec_nc_push_constants {
@@ -863,7 +869,7 @@ struct vk_mat_vec_nc_push_constants {
     uint32_t nb03;
     uint32_t nb13;
     uint32_t nb23;
-    uint32_t enable_bias;
+    uint32_t fusion_flags;
 };
 
 struct vk_mat_mat_id_push_constants {
@@ -881,8 +887,7 @@ struct vk_mat_vec_id_push_constants {
     uint32_t batch_stride_a;
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
-    uint32_t enable_bias;
-    uint32_t enable_scale;
+    uint32_t fusion_flags;
     uint32_t nei0;
     uint32_t ne11;
 };
@@ -3465,8 +3470,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
     const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
-    static constexpr uint32_t mul_mat_vec_num_bindings = 4;
-    static constexpr uint32_t mul_mat_vec_id_num_bindings = 5;
+    static constexpr uint32_t mul_mat_vec_num_bindings = 5;
+    static constexpr uint32_t mul_mat_vec_id_num_bindings = 6;
 
     for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
         const uint32_t wg_size_subgroup   = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
@@ -6871,21 +6876,31 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t fusion_flags = 0;
 
-    vk_subbuffer d_B = d_D;
-
-    if (enable_bias) {
+    vk_subbuffer d_F0 = d_D;
+    if (ctx->num_additional_fused_ops > 0) {
         const ggml_tensor * add = cgraph->nodes[node_idx + 1];
         const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 
-        d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+    }
+
+    vk_subbuffer d_F1 = d_D;
+    if (ctx->num_additional_fused_ops == 2) {
+        const ggml_tensor * add = cgraph->nodes[node_idx + 2];
+        const ggml_tensor * bias = add->src[0] == cgraph->nodes[node_idx + 1] ? add->src[1] : add->src[0];
+
+        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
+        stride_batch_x, stride_batch_y, stride_batch_d,
+        fusion_flags,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -6893,7 +6908,8 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
                                 d_X,
                                 d_Y,
                                 d_D,
-                                d_B,
+                                d_F0,
+                                d_F1,
                               },
                               pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
 
@@ -6946,22 +6962,31 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
     vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
 
-    vk_subbuffer d_B = d_D;
+    vk_subbuffer d_F0 = d_D;
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t fusion_flags = 0;
 
-    if (enable_bias) {
+    if (ctx->num_additional_fused_ops > 0) {
         const ggml_tensor * add = cgraph->nodes[node_idx + 1];
         const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 
-        d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+    }
+
+    vk_subbuffer d_F1 = d_D;
+    if (ctx->num_additional_fused_ops > 1) {
+        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
+
+        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
     // compute
 
     vk_mat_vec_p021_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12,
-        0, 0, enable_bias
+        0, 0, fusion_flags
     };
 
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -6977,7 +7002,8 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
             d_Qx,
             d_Qy,
             d_D,
-            d_B,
+            d_F0,
+            d_F1,
         }, pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
@@ -7029,15 +7055,24 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops], true);
     vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
     vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1, true);
-    vk_subbuffer d_B = d_D;
+    vk_subbuffer d_F0 = d_D;
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t fusion_flags = 0;
 
-    if (enable_bias) {
+    if (ctx->num_additional_fused_ops > 0) {
         const ggml_tensor * add = cgraph->nodes[node_idx + 1];
         const ggml_tensor * bias = add->src[0] == dst ? add->src[1] : add->src[0];
 
-        d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
+    }
+
+    vk_subbuffer d_F1 = d_D;
+    if (ctx->num_additional_fused_ops > 1) {
+        const ggml_tensor * bias = cgraph->nodes[node_idx + 2]->src[1];
+
+        d_F1 = ggml_vk_tensor_subbuffer(ctx, bias);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
     }
 
     // compute
@@ -7046,7 +7081,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
         row_stride_x, channel_stride_x, channel_stride_y,
         (uint32_t)(ne12 / ne02), (uint32_t)ne12,
         0, 0,
-        nb03, nb13, nb23, enable_bias
+        nb03, nb13, nb23, fusion_flags
     };
 
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7056,7 +7091,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
             d_Qx,
             d_Qy,
             d_D,
-            d_B,
+            d_F0,
+            d_F1,
         }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
 }
 
@@ -7477,7 +7513,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     vk_subbuffer d_Qx = ggml_vk_tensor_subbuffer(ctx, src0);
     vk_subbuffer d_Qy = ggml_vk_tensor_subbuffer(ctx, src1);
     vk_subbuffer d_ids = ggml_vk_tensor_subbuffer(ctx, ids);
-    vk_subbuffer d_B = d_D;
+    vk_subbuffer d_F0 = d_D;
     vk_subbuffer d_X, d_Y;
 
     if (qx_needs_dequant) {
@@ -7530,30 +7566,34 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
-    uint32_t enable_bias = 0;
-    uint32_t enable_scale = 0;
+    uint32_t fusion_flags = 0;
+
     if (ctx->num_additional_fused_ops > 0) {
+        const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
+
+        d_F0 = ggml_vk_tensor_subbuffer(ctx, bias);
+
         if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
-            enable_scale = 1;
+            fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE0;
         } else {
             GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
-            enable_bias = 1;
+            fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS0;
         }
     }
 
-    if (enable_bias || enable_scale) {
-        const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
+    vk_subbuffer d_F1 = d_D;
+    if (ctx->num_additional_fused_ops > 1) {
+        const ggml_tensor * scale = cgraph->nodes[node_idx + 2]->src[1];
 
-        d_B = ggml_vk_tensor_subbuffer(ctx, bias);
+        d_F1 = ggml_vk_tensor_subbuffer(ctx, scale);
+        fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
     }
 
     // compute
     const vk_mat_vec_id_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-
-        enable_bias, enable_scale,
-
+        fusion_flags,
         (uint32_t)nei0, (uint32_t)ne11,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7561,7 +7601,8 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
             d_X,
             d_Y,
             d_D,
-            d_B,
+            d_F0,
+            d_F1,
             d_ids,
         },
         pc, { groups_x, (uint32_t)nei0, groups_z });
@@ -12305,10 +12346,7 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
             return false;
         }
     }
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
-        // additional constraints specific to this fusion
-        const ggml_tensor *mul = cgraph->nodes[node_idx];
-        const ggml_tensor *add = cgraph->nodes[node_idx + 1];
+    auto const &mm_add_ok = [&](const ggml_tensor *mul, const ggml_tensor *add) {
         const ggml_tensor *bias = add->src[0] == mul ? add->src[1] : add->src[0];
 
         // mat-vec only
@@ -12328,14 +12366,31 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
         if (get_misalign_bytes(ctx, bias) != 0) {
             return false;
         }
-    }
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
+        return true;
+    };
+
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT && ops.begin()[1] == GGML_OP_ADD) {
         // additional constraints specific to this fusion
         const ggml_tensor *mul = cgraph->nodes[node_idx];
         const ggml_tensor *add = cgraph->nodes[node_idx + 1];
-        const ggml_tensor *bias = add->src[1];
 
-        if (mul != add->src[0]) {
+        if (!mm_add_ok(mul, add)) {
+            return false;
+        }
+        if (ops.size() == 3) {
+            if (ops.begin()[2] != GGML_OP_ADD) {
+                return false;
+            }
+            if (!mm_add_ok(add, cgraph->nodes[node_idx + 2])) {
+                return false;
+            }
+        }
+    }
+
+    auto const &mmid_mul_ok = [&](const ggml_tensor *mmid, const ggml_tensor *mul) {
+        const ggml_tensor *scale = mul->src[1];
+
+        if (mmid != mul->src[0]) {
             return false;
         }
         // mat-vec only
@@ -12343,30 +12398,34 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
             return false;
         }
         // shaders assume the types match
-        if (mul->type != bias->type) {
+        if (mmid->type != scale->type) {
             return false;
         }
         // shaders assume the bias is contiguous
-        if (!ggml_is_contiguous(bias)) {
+        if (!ggml_is_contiguous(scale)) {
             return false;
         }
-        // the ID tensor must be the same for mul_mat_id and add_id
-        if (mul->src[2] != add->src[2]) {
+        // unaligned bias isn't handled
+        if (get_misalign_bytes(ctx, scale) != 0) {
             return false;
         }
-        // unaligned bias isn't handled
-        if (get_misalign_bytes(ctx, bias) != 0) {
+        // shader only indexes by expert index
+        if (scale->ne[0] != 1 ||
+            scale->ne[1] != mul->ne[1] ||
+            scale->ne[2] != 1 ||
+            scale->ne[3] != 1) {
             return false;
         }
-    }
+        return true;
+    };
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_ADD_ID) {
         // additional constraints specific to this fusion
-        const ggml_tensor *mmid = cgraph->nodes[node_idx];
-        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
-        const ggml_tensor *scale = mul->src[1];
+        const ggml_tensor *mul = cgraph->nodes[node_idx];
+        const ggml_tensor *add = cgraph->nodes[node_idx + 1];
+        const ggml_tensor *bias = add->src[1];
 
-        if (mmid != mul->src[0]) {
+        if (mul != add->src[0]) {
             return false;
         }
         // mat-vec only
@@ -12374,22 +12433,37 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
             return false;
         }
         // shaders assume the types match
-        if (mmid->type != scale->type) {
+        if (mul->type != bias->type) {
             return false;
         }
         // shaders assume the bias is contiguous
-        if (!ggml_is_contiguous(scale)) {
+        if (!ggml_is_contiguous(bias)) {
+            return false;
+        }
+        // the ID tensor must be the same for mul_mat_id and add_id
+        if (mul->src[2] != add->src[2]) {
             return false;
         }
         // unaligned bias isn't handled
-        if (get_misalign_bytes(ctx, scale) != 0) {
+        if (get_misalign_bytes(ctx, bias) != 0) {
             return false;
         }
-        // shader only indexes by expert index
-        if (scale->ne[0] != 1 ||
-            scale->ne[1] != mul->ne[1] ||
-            scale->ne[2] != 1 ||
-            scale->ne[3] != 1) {
+
+        if (ops.size() == 3) {
+            if (ops.begin()[2] != GGML_OP_MUL) {
+                return false;
+            }
+            const ggml_tensor *mul = cgraph->nodes[node_idx + 2];
+            return mmid_mul_ok(add, mul);
+        }
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *mmid = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        if (!mmid_mul_ok(mmid, mul)) {
             return false;
         }
     }
@@ -12704,8 +12778,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
             if (num_adds) {
                 ctx->num_additional_fused_ops = num_adds - 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
+                ctx->num_additional_fused_ops = 2;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 2;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
@@ -12872,6 +12950,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
 
     std::vector<ggml_tensor *> new_order;
     std::vector<bool> used(graph->n_nodes, false);
+    std::set<ggml_tensor *> used_node_set;
+
     int first_unused = 0;
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;
@@ -12894,6 +12974,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             if (match_pattern(pattern, first_unused)) {
                 for (size_t j = 0; j < pattern.size(); ++j) {
                     new_order.push_back(graph->nodes[first_unused + j]);
+                    used_node_set.insert(graph->nodes[first_unused + j]);
                     used[first_unused + j] = true;
                 }
                 while (first_unused < graph->n_nodes && used[first_unused]) {
@@ -12997,6 +13078,36 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                         used[set_rows_idx] = true;
                     }
                 }
+                // Look for MUL_MAT_ID + ADD_ID + MUL
+                if (j > 0 &&
+                    graph->nodes[j]->op == GGML_OP_ADD_ID &&
+                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT_ID) {
+                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                        if (graph->nodes[k]->op == GGML_OP_MUL &&
+                            graph->nodes[k]->src[0] == graph->nodes[j] &&
+                            // src1 must either be weights or already processed
+                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+                            current_set.push_back(k);
+                            used[k] = true;
+                            break;
+                        }
+                    }
+                }
+                // Look for MUL_MAT + ADD + ADD
+                if (j > 0 &&
+                    graph->nodes[j]->op == GGML_OP_ADD &&
+                    graph->nodes[j-1]->op == GGML_OP_MUL_MAT) {
+                    for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
+                        if (graph->nodes[k]->op == GGML_OP_ADD &&
+                            graph->nodes[k]->src[0] == graph->nodes[j] &&
+                            // src1 must either be weights or already processed
+                            (graph->nodes[k]->src[1]->op == GGML_OP_NONE || used_node_set.find(graph->nodes[k]->src[1]) != used_node_set.end())) {
+                            current_set.push_back(k);
+                            used[k] = true;
+                            break;
+                        }
+                    }
+                }
             }
         }
         // Second pass grabs view nodes.
@@ -13029,6 +13140,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
         // Push the current set into new_order
         for (auto c : current_set) {
             new_order.push_back(graph->nodes[c]);
+            used_node_set.insert(graph->nodes[c]);
             used[c] = true;
         }
         while (first_unused < graph->n_nodes && used[first_unused]) {

+ 47 - 49
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

@@ -11,29 +11,7 @@
 #define EXPERT_COUNT 8
 #endif
 
-#include "types.glsl"
-
-#ifndef MMQ
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-#else
-layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
-#endif
-
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-#ifdef B_TYPE_VEC2
-layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
-#endif
-#ifdef B_TYPE_VEC4
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-#endif
-
-layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
-
-layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
-
-#ifdef MUL_MAT_ID
-layout (binding = 4) readonly buffer IDS {int data_ids[];};
-#endif
+#include "mul_mat_vec_iface.glsl"
 
 #include "dequant_funcs.glsl"
 
@@ -48,8 +26,7 @@ layout (push_constant) uniform parameter
     uint batch_stride_b;
     uint batch_stride_d;
 
-    uint enable_bias;
-    uint enable_scale;
+    uint fusion_flags;
 
 #ifdef MUL_MAT_ID
     uint nei0;
@@ -123,17 +100,24 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
     if (tid == 0) {
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                if (p.enable_bias != 0) {
 #ifdef MUL_MAT_ID
-                    temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
-                    temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
-#ifdef MUL_MAT_ID
-                if (p.enable_scale != 0) {
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                     const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                }
+#else
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                 }
 #endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@@ -171,17 +155,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                 [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
                     temp[j][n] += tmpsh[j][n][s];
                 }
-                if (p.enable_bias != 0) {
 #ifdef MUL_MAT_ID
-                    temp[j][n] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
-                    temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
-#ifdef MUL_MAT_ID
-                if (p.enable_scale != 0) {
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
                     const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                }
+#else
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+                    temp[j][n] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                 }
 #endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
@@ -209,17 +200,24 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
     if (tid == 0) {
         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
             [[unroll]] for (uint n = 0; n < num_rows; ++n) {
-                if (p.enable_bias != 0) {
 #ifdef MUL_MAT_ID
-                    tmpsh[j][n][0] += FLOAT_TYPE(data_bias[expert_id*p.stride_d + first_row + n]);
-#else
-                    tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
-#endif
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
-#ifdef MUL_MAT_ID
-                if (p.enable_scale != 0) {
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
                     const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                }
+#else
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[j*p.batch_stride_d + d_offset + first_row + n]);
+                }
+                if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+                    tmpsh[j][n][0] += FLOAT_TYPE(data_fuse1[j*p.batch_stride_d + d_offset + first_row + n]);
                 }
 #endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);

+ 33 - 0
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl

@@ -0,0 +1,33 @@
+#include "types.glsl"
+
+#define MAT_VEC_FUSION_FLAGS_BIAS0 0x1
+#define MAT_VEC_FUSION_FLAGS_BIAS1 0x2
+#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
+#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
+
+#ifndef MMQ
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+#if defined(A_TYPE_VEC4)
+layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
+#endif
+#else
+layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
+#endif
+
+layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
+#ifdef B_TYPE_VEC2
+layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
+#endif
+#ifdef B_TYPE_VEC4
+layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
+#endif
+
+layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
+
+layout (binding = 3) readonly buffer Fuse0 {D_TYPE data_fuse0[];};
+layout (binding = 4) readonly buffer Fuse1 {D_TYPE data_fuse1[];};
+
+#ifdef MUL_MAT_ID
+layout (binding = 5) readonly buffer IDS {int data_ids[];};
+#endif
+

+ 8 - 12
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp

@@ -8,14 +8,7 @@
 
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-
-layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
+#include "mul_mat_vec_iface.glsl"
 
 layout (push_constant) uniform parameter
 {
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
     uint nb03;
     uint nb13;
     uint nb23;
-    uint enable_bias;
+    uint fusion_flags;
 } p;
 
 shared FLOAT_TYPE tmp[BLOCK_SIZE];
@@ -120,9 +113,12 @@ void main() {
     }
 
     if (tid == 0) {
-        if (p.enable_bias != 0) {
-            tmp[0] += FLOAT_TYPE(data_bias[idst]);
+        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+            tmp[0] += FLOAT_TYPE(data_fuse0[idst]);
+        }
+        if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+            tmp[0] += FLOAT_TYPE(data_fuse1[idst]);
         }
-        dst[idst] = tmp[0];
+        data_d[idst] = tmp[0];
     }
 }

+ 8 - 12
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp

@@ -10,14 +10,7 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
-layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
-layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
-
-layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
-layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
-
-layout (binding = 3) readonly buffer Bias {D_TYPE data_bias[];};
+#include "mul_mat_vec_iface.glsl"
 
 layout(constant_id = 0) const int BLOCK_SIZE = 32;
 // gqa_ratio is in the range [1,8]
@@ -31,7 +24,7 @@ layout (push_constant) uniform parameter
     uint nchannels_y;
     uint b_offset;
     uint d_offset;
-    uint enable_bias;
+    uint fusion_flags;
 } p;
 
 #if !USE_SUBGROUP_ADD
@@ -151,10 +144,13 @@ void main() {
         [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
             // dst is not transposed and not permuted
             const uint idst = (channel + c)*nrows_dst + row_dst;
-            if (p.enable_bias != 0) {
-                temp[c] += FLOAT_TYPE(data_bias[idst]);
+            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
+                temp[c] += FLOAT_TYPE(data_fuse0[idst]);
+            }
+            if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS1) != 0) {
+                temp[c] += FLOAT_TYPE(data_fuse1[idst]);
             }
-            dst[idst] = temp[c];
+            data_d[idst] = temp[c];
         }
     }
 }

+ 19 - 5
tests/test-backend-ops.cpp

@@ -5002,17 +5002,19 @@ struct test_mul_mat_vec_fusion : public test_case {
     const bool b;        // broadcast b matrix (only for use_id)
     const bool with_bias;
     const bool with_gate;
+    std::array<int64_t, 2> batch_dims;
 
     test_mul_mat_vec_fusion(ggml_type type, ggml_glu_op op, int64_t m, int64_t n, int64_t k,
-                        bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true)
-    : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate) {
+                        bool use_id = false, int n_mats = 1, int n_used = 1, bool b = false, bool with_bias = false, bool with_gate = true,
+                        std::array<int64_t, 2> batch_dims = {4, 2})
+    : type(type), glu_op(op), m(m), n(n), k(k), use_id(use_id), n_mats(n_mats), n_used(n_used), b(b), with_bias(with_bias), with_gate(with_gate), batch_dims(batch_dims) {
         if (use_id) {
             GGML_ASSERT(n_used <= n_mats);
         }
     }
 
     std::string vars() override {
-        return VARS_TO_STR11(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate);
+        return VARS_TO_STR12(type, glu_op, m, n, k, use_id, n_mats, n_used, b, with_bias, with_gate, batch_dims);
     }
 
     std::string op_desc(ggml_tensor * t) override {
@@ -5038,8 +5040,8 @@ struct test_mul_mat_vec_fusion : public test_case {
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         if (!use_id) {
-            const int              channels = 4;
-            const int              samples  = 2;
+            const int              channels = batch_dims[0];
+            const int              samples  = batch_dims[1];
             std::array<int64_t, 4> ne       = { k, m, channels, samples };
             std::array<int64_t, 4> ne0      = { k, n, channels, samples };
 
@@ -5062,6 +5064,11 @@ struct test_mul_mat_vec_fusion : public test_case {
             }
 
             ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
+            std::array<int64_t, 4> bias2_ne   = { out->ne[0], 1, channels, samples };
+            ggml_tensor * bias2 = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, bias2_ne.data());
+            out = ggml_add(ctx, out, bias2);
+
             ggml_set_name(out, "out");
             return out;
         } else {
@@ -5089,6 +5096,11 @@ struct test_mul_mat_vec_fusion : public test_case {
             }
 
             ggml_tensor * out = with_gate ? build_gate(ctx, ffn_gate, ffn_up) : ffn_up;
+
+            std::array<int64_t, 4> scale_ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+            ggml_tensor * scale = ggml_new_tensor(ctx, out->type, 4, scale_ne.data());
+            out = ggml_mul(ctx, out, scale);
+
             ggml_set_name(out, "out");
             return out;
         }
@@ -7645,6 +7657,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                             }
                             test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
                                 use_id, 16, 8, b, with_bias, with_gate));
+                            test_cases.emplace_back(new test_mul_mat_vec_fusion(type, glu_op, 1, 32, 256,
+                                use_id, 16, 8, b, with_bias, with_gate, {1, 1}));
                         }
                     }
                 }