@@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
     uint32_t fusion_flags;
     uint32_t nei0;
     uint32_t ne11;
+    uint32_t expert_i1;
+    uint32_t nbi1;
 };
 
 struct vk_flash_attn_push_constants {
@@ -8083,8 +8085,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-
-    GGML_ASSERT(nei1 == 1);
+    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
 
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8169,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8227,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     uint32_t stride_batch_y = ne10*ne11;
 
     if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
     }
 
     const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8263,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
     }
 
-    // compute
-    const vk_mat_vec_id_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-        fusion_flags,
-        (uint32_t)nei0, (uint32_t)ne11,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        {
-            d_X,
-            d_Y,
-            d_D,
-            d_F0,
-            d_F1,
-            d_ids,
-        },
-        pc, { groups_x, (uint32_t)nei0, groups_z });
+    // Loop over the batch dimension
+    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
+        const vk_mat_vec_id_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
+            fusion_flags,
+            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+            {
+                d_X,
+                d_Y,
+                d_D,
+                d_F0,
+                d_F1,
+                d_ids,
+            },
+            pc, { groups_x, (uint32_t)nei0, groups_z });
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8295,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
     ggml_tensor * dst = cgraph->nodes[node_idx];
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src2 = dst->src[2];
-    return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
+    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
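For readers following the change: the two new push constants let each dispatch in the loop above locate its expert ids. expert_i1 is the index along the batch (nei1) dimension of the ids tensor handled by that dispatch, and nbi1 is the stride between batch entries of ids in int32 elements, so element (i0, i1) of ids sits at flat index i1 * nbi1 + i0. Below is a minimal host-side sketch of that indexing; the function name pick_expert_id, the ids_data pointer and the i0 parameter are illustrative and not taken from the patch or the shaders.

#include <cstdint>

// Sketch only: reproduces the ids indexing implied by the new push constants.
// ids_data points at the int32 expert ids of the GGML_OP_MUL_MAT_ID ids tensor;
// nbi1 is ids->nb[1] / sizeof(int), i.e. the stride between batch entries.
static int32_t pick_expert_id(const int32_t * ids_data, uint32_t nbi1,
                              uint32_t expert_i1, uint32_t i0) {
    // expert_i1 is the batch index chosen by the host-side loop;
    // i0 (0 <= i0 < nei0) corresponds to the y dimension of the dispatch,
    // matching the { groups_x, (uint32_t)nei0, groups_z } group counts
    // passed to ggml_vk_dispatch_pipeline.
    return ids_data[expert_i1 * nbi1 + i0];
}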