
vulkan: Fix mismatch in TOPK_MOE unit test (#17541)

* Fix shader to support 2D workgroup mapping to a single subgroup

* Set required_subgroup_size

The topk_moe shader requires the static WARP_SIZE and the actual subgroup size to match
Masato Nakasaka 1 month ago
Parent
Commit
d8c0a7b085
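
The second bullet is the core of the fix: the shader bakes WARP_SIZE in as a specialization constant, so the pipeline must also be forced to run at exactly that subgroup size. A minimal sketch of the underlying Vulkan mechanism follows (the ggml_vk_create_pipeline2 plumbing itself is not part of this diff; vk_device_handle, shader_module, layout and subgroup_size are assumed placeholders):

    #include <vulkan/vulkan.h>

    // Pin the compute stage to a fixed subgroup size (core in Vulkan 1.3,
    // previously VK_EXT_subgroup_size_control) so it matches the WARP_SIZE
    // specialization constant baked into the shader.
    VkPipelineShaderStageRequiredSubgroupSizeCreateInfo subgroup_info = {};
    subgroup_info.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO;
    subgroup_info.requiredSubgroupSize = subgroup_size; // e.g. device->subgroup_size

    VkPipelineShaderStageCreateInfo stage = {};
    stage.sType  = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
    stage.pNext  = &subgroup_info;                      // chain the size requirement in
    // Optionally also require full subgroups so no row is handled by a partial one:
    stage.flags  = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT;
    stage.stage  = VK_SHADER_STAGE_COMPUTE_BIT;
    stage.module = shader_module;
    stage.pName  = "main";

    VkComputePipelineCreateInfo pipeline_info = {};
    pipeline_info.sType  = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
    pipeline_info.stage  = stage;
    pipeline_info.layout = layout;

    VkPipeline pipeline;
    vkCreateComputePipelines(vk_device_handle, VK_NULL_HANDLE, 1, &pipeline_info, nullptr, &pipeline);

Without the requiredSubgroupSize chain, the subgroup size actually used at dispatch may differ from the WARP_SIZE the shader was specialized for.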

+ 3 - 3
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -4174,9 +4174,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),         topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),         topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
     }
 
     for (auto &c : compiles) {
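
The array {device->subgroup_size, 1u<<i, ...} in the calls above holds the 32-bit specialization constants handed to the shader; the first entry is what the shader sees as WARP_SIZE, which is why the new trailing argument pins the required subgroup size to the same value. A minimal sketch of how such constants reach a SPIR-V module (generic Vulkan, not the ggml helper; the constant IDs and the n_experts_pow2 name are assumptions for illustration):

    #include <vulkan/vulkan.h>
    #include <cstdint>

    // Four 32-bit specialization constants, matching layout(constant_id = 0..3)
    // declarations in the shader; the first one is used as WARP_SIZE.
    const uint32_t spec_data[4] = { subgroup_size, n_experts_pow2, 0u, 1u };

    VkSpecializationMapEntry entries[4];
    for (uint32_t c = 0; c < 4; ++c) {
        entries[c].constantID = c;
        entries[c].offset     = c * (uint32_t) sizeof(uint32_t);
        entries[c].size       = sizeof(uint32_t);
    }

    VkSpecializationInfo spec_info = {};
    spec_info.mapEntryCount = 4;
    spec_info.pMapEntries   = entries;
    spec_info.dataSize      = sizeof(spec_data);
    spec_info.pData         = spec_data;

    // Attached to the compute stage before pipeline creation:
    // stage.pSpecializationInfo = &spec_info;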

+ 10 - 9
ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

@@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
 }
 
 void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
     if (row >= n_rows) {
         return;
     }
@@ -83,17 +83,18 @@ void main() {
     const uint logits_offset = n_experts * row;
     const uint weights_offset = n_expert_used * row;
     const uint ids_offset = n_experts * row;
+    const uint lane = gl_SubgroupInvocationID;
 
     float wt[experts_per_thread];
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert = i + gl_LocalInvocationID.x;
+        const uint expert = i + lane;
         wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
     }
 
     if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
+        softmax_warp_inplace(wt, n_experts, lane, false);
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -111,11 +112,11 @@ void main() {
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
-        uint   max_expert = gl_LocalInvocationID.x;
+        uint   max_expert = lane;
 
         [[unroll]]
         for (int i = 1; i < experts_per_thread; i++) {
-            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
+            const uint expert = lane + i * WARP_SIZE;
             if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                 max_val    = wt[i];
                 max_expert = expert;
@@ -132,11 +133,11 @@ void main() {
             }
         }
 
-        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((k & (WARP_SIZE - 1)) == lane) {
             output_weights[k / WARP_SIZE] = max_val;
         }
 
-        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+        if ((max_expert & (WARP_SIZE - 1)) == lane) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
             ids[ids_offset + k] = max_expert;
@@ -158,12 +159,12 @@ void main() {
     }
 
     if (late_softmax) {
-        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+        softmax_warp_inplace(output_weights, n_expert_used, lane, true);
     }
 
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
-        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+        uint idx = i * WARP_SIZE + lane;
         if (idx < n_expert_used) {
             weights[weights_offset + idx] = output_weights[i];
         }
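
As read from the hunks above, each row of the batch is handled by one subgroup: gl_SubgroupID selects the row within the workgroup and gl_SubgroupInvocationID strides over the experts in steps of WARP_SIZE. The following scalar reference is a sketch of what one subgroup computes cooperatively for its row, for illustration only; the _norm variant and the subgroup reductions inside softmax_warp_inplace are not reproduced, and the helper name is hypothetical.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    // Scalar model of one row of TOPK_MOE: softmax over the expert logits,
    // then pick the n_expert_used largest weights and their expert ids.
    static void topk_moe_row_reference(const std::vector<float> &logits,
                                       uint32_t n_expert_used,
                                       bool late_softmax,
                                       std::vector<float> &weights,   // size n_expert_used
                                       std::vector<uint32_t> &ids) {  // size n_expert_used
        std::vector<float> wt = logits;

        auto softmax_inplace = [](std::vector<float> &v) {
            float max_val = -INFINITY;
            for (float x : v) max_val = std::max(max_val, x);
            float sum = 0.0f;
            for (float &x : v) { x = std::exp(x - max_val); sum += x; }
            for (float &x : v) x /= sum;
        };

        if (!late_softmax) softmax_inplace(wt);   // the "early softmax" variants

        weights.resize(n_expert_used);
        ids.resize(n_expert_used);
        for (uint32_t k = 0; k < n_expert_used; ++k) {
            uint32_t max_expert = 0;
            for (uint32_t e = 1; e < wt.size(); ++e) {
                if (wt[e] > wt[max_expert]) max_expert = e;
            }
            weights[k] = wt[max_expert];
            ids[k]     = max_expert;
            wt[max_expert] = -INFINITY;           // exclude from the next round
        }

        if (late_softmax) softmax_inplace(weights);  // softmax over the selected top-k only
    }

In the shader this work is spread across the WARP_SIZE lanes of the subgroup, which is why the static WARP_SIZE and the actual subgroup size have to agree.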