Browse Source

vulkan: Use mul_mat_vec_id for small values of n (#18918)

Change ggml_vk_mul_mat_vec_id_q_f16 to loop over the batch dimension and
update the indexing calculations in get_offsets.

Mat-vec is faster than mat-mat for small values of n. We don't get the same
reuse of the weights as in the non-ID path, but with this the cost is linear
in n rather than n>1 being far slower than n==1.
Jeff Bolz 1 week ago
parent
commit
50b7f076a5

+ 25 - 22
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -991,6 +991,8 @@ struct vk_mat_vec_id_push_constants {
     uint32_t fusion_flags;
     uint32_t nei0;
     uint32_t ne11;
+    uint32_t expert_i1;
+    uint32_t nbi1;
 };
 
 struct vk_flash_attn_push_constants {
@@ -8083,8 +8085,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-
-    GGML_ASSERT(nei1 == 1);
+    const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
 
     const uint64_t ne20 = dst->ne[0];
     const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8169,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         if (quantize_y) {
             ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
         }
-        ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
     }
 
     vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8227,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
     uint32_t stride_batch_y = ne10*ne11;
 
     if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
-        stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
+        stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
     }
 
     const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8263,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
     }
 
-    // compute
-    const vk_mat_vec_id_push_constants pc = {
-        (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
-        fusion_flags,
-        (uint32_t)nei0, (uint32_t)ne11,
-    };
-    ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
-        {
-            d_X,
-            d_Y,
-            d_D,
-            d_F0,
-            d_F1,
-            d_ids,
-        },
-        pc, { groups_x, (uint32_t)nei0, groups_z });
+    // Loop over the batch dimension
+    for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
+        const vk_mat_vec_id_push_constants pc = {
+            (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
+            (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
+            fusion_flags,
+            (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
+        };
+        ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
+            {
+                d_X,
+                d_Y,
+                d_D,
+                d_F0,
+                d_F1,
+                d_ids,
+            },
+            pc, { groups_x, (uint32_t)nei0, groups_z });
+    }
 
     if (x_non_contig) {
         ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8295,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
     ggml_tensor * dst = cgraph->nodes[node_idx];
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src2 = dst->src[2];
-    return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
+    return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
 }
 
 static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {

+ 18 - 16
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl

@@ -29,6 +29,8 @@ layout (push_constant) uniform parameter
 #ifdef MUL_MAT_ID
     uint nei0;
     uint ne11;
+    uint expert_i1;
+    uint nbi1;
 #else
     uint ne02;
     uint ne12;
@@ -43,7 +45,7 @@ uint expert_id;
 
 void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #ifdef MUL_MAT_ID
-    const uint expert_idx = gl_GlobalInvocationID.y;
+    const uint expert_i0 = gl_GlobalInvocationID.y;
 #else
     const uint batch_idx = gl_GlobalInvocationID.y;
 #endif
@@ -60,7 +62,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
         batch_idx_a = i03 * p.ne02 + i02;
     }
 #else
-    expert_id = data_ids[expert_idx];
+    expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1];
 #endif
 
     a_offset =
@@ -71,13 +73,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #endif
     b_offset =
 #ifdef MUL_MAT_ID
-            (expert_idx % p.ne11) * p.stride_b;
+            (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b;
 #else
             batch_idx * p.batch_stride_b;
 #endif
     d_offset =
 #ifdef MUL_MAT_ID
-            expert_idx * p.stride_d;
+            expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d;
 #else
             batch_idx * p.batch_stride_d;
 #endif
@@ -103,12 +105,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -158,12 +160,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {
@@ -203,12 +205,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]);
                 }
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) {
-                    const uint expert_idx = gl_GlobalInvocationID.y;
-                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]);
+                    const uint expert_i0 = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]);
                 }
 #else
                 if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) {