Parcourir la source

vulkan: Implement top-k (#17418)

* vulkan: Implement top-k

Each pass launches workgroups that each sort 2^N elements (where N is usually 7-10)
and discards all but the top K. Repeat until only K are left. And there's a fast
path when K==1 to just find the max value rather than sorting.

* fix pipeline selection

* vulkan: Add N-ary search algorithm for topk

* microoptimizations
Jeff Bolz il y a 1 mois
Parent
commit
879d673759

+ 150 - 0
ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -409,6 +409,7 @@ enum shader_reduction_mode {
 // argsort pipelines for up to 1<<10 invocations per workgroup
 static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t num_topk_moe_pipelines = 10;
+static constexpr uint32_t num_topk_pipelines = 11;
 
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
@@ -515,6 +516,7 @@ struct vk_device_struct {
     bool single_queue;
     bool support_async;
     uint32_t subgroup_size;
+    uint32_t subgroup_size_log2;
     uint32_t shader_core_count;
     bool uma;
     bool prefer_host_memory;
@@ -704,6 +706,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
+    vk_pipeline pipeline_topk_f32[num_topk_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
@@ -1205,6 +1208,15 @@ struct vk_op_argsort_push_constants {
     uint32_t inner_end;
 };
 
+struct vk_op_topk_push_constants {
+    uint32_t orig_ncols;
+    uint32_t ncols_input;
+    uint32_t ncols_output;
+    uint32_t nrows;
+    uint32_t first_pass;
+    uint32_t last_pass;
+};
+
 struct vk_op_im2col_push_constants {
     uint64_t dst_addr;
     uint32_t batch_offset; uint32_t offset_delta;
@@ -3965,6 +3977,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline2(device, device->pipeline_argsort_large_f32[i], "argsort_large_f32_"+std::to_string(i), argsort_large_f32_len, argsort_large_f32_data, "main", 3, sizeof(vk_op_argsort_push_constants), {BLOCK_SIZE * WG_UNROLL_FACTOR, 1, 1}, {BLOCK_SIZE, WG_UNROLL_FACTOR}, 1, true);
     }
 
+    for (uint32_t i = 0; i < num_topk_pipelines; ++i) {
+        const uint32_t BLOCK_SIZE = 1u << i;
+        const uint32_t NCOLS_PADDED_LOG2 = i;
+        if (i <= device->max_workgroup_size_log2) {
+            uint32_t nary_shmem = 2 * sizeof(int) * BLOCK_SIZE +
+                                  sizeof(int) * device->subgroup_size +
+                                  2 * sizeof(int) +
+                                  (BLOCK_SIZE / device->subgroup_size) * sizeof(int);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support && device->subgroup_shuffle && device->subgroup_ballot &&
+                nary_shmem <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_nary_search_f32_len, topk_nary_search_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, device->subgroup_size, device->subgroup_size_log2}, 1, true, true, device->subgroup_size);
+            } else if (2 * sizeof(int) * BLOCK_SIZE <= device->properties.limits.maxComputeSharedMemorySize) {
+                ggml_vk_create_pipeline2(device, device->pipeline_topk_f32[i], "topk_f32_"+std::to_string(i), topk_argsort_f32_len, topk_argsort_f32_data, "main", 2, sizeof(vk_op_topk_push_constants), {BLOCK_SIZE, 1, 1}, {BLOCK_SIZE, NCOLS_PADDED_LOG2}, 1, true);
+            }
+        }
+    }
+
     ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
@@ -4336,6 +4365,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
         device->subgroup_size = subgroup_props.subgroupSize;
+        device->subgroup_size_log2 = uint32_t(log2f(float(device->subgroup_size)));
         device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
         if (sm_builtins) {
             device->shader_core_count = sm_props.shaderSMCount;
@@ -10143,6 +10173,104 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }
 }
 
+static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    uint32_t ncols = src0->ne[0];
+    uint32_t nrows = ggml_nrows(src0);
+    uint32_t k = dst->ne[0];
+
+    vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
+
+    // Reserve space for ivec2 per element, double buffered
+    const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
+    const size_t x_sz = dbl_buf_size * 2;
+    uint32_t dbl_buf_index = 0;
+
+    if (ctx->prealloc_size_x < x_sz) {
+        ctx->prealloc_size_x = x_sz;
+        ggml_vk_preallocate_buffers(ctx, subctx);
+    }
+    if (ctx->prealloc_x_need_sync) {
+        ggml_vk_sync_buffers(ctx, subctx);
+    }
+
+    std::array<uint32_t, 3> elements;
+    elements[1] = std::min(nrows, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
+    elements[2] = 1;
+
+    uint32_t num_elements = ncols;
+
+    // Each iteration reduces a workgroup's worth of elements down to the K
+    // largest elements. Repeat until we have the top K elements.
+    // Need to do at least one iteration to write out the results.
+    bool done_one_iter = false;
+    while (num_elements > k || !done_one_iter) {
+        done_one_iter = true;
+
+        // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
+        // But if K is larger, then we need a larger workgroup
+        uint32_t max_pipeline = num_topk_pipelines - 3;
+        uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
+        // require full subgroup
+        min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
+
+        uint32_t pipeline_idx = (uint32_t)ceilf(log2f(float(num_elements)));
+        pipeline_idx = std::min(pipeline_idx, max_pipeline);
+        pipeline_idx = std::max(pipeline_idx, min_pipeline);
+
+        if (num_elements > (1u << pipeline_idx)) {
+            // If we could finish on this loop iteration (i.e. a single workgroup)
+            // then do so. It's better than the overhead of another pass.
+            for (uint32_t i = pipeline_idx; i < num_topk_pipelines; ++i) {
+                if (num_elements <= (1u << i)) {
+                    pipeline_idx = i;
+                    break;
+                }
+            }
+        }
+
+        vk_pipeline pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        // If the device doesn't support a pipeline this large, use smaller
+        while (!pipeline) {
+            pipeline_idx--;
+            GGML_ASSERT(pipeline_idx >= min_pipeline);
+            pipeline = ctx->device->pipeline_topk_f32[pipeline_idx];
+        }
+
+        vk_op_topk_push_constants pc2 = pc;
+        pc2.ncols_input = num_elements;
+
+        // Number of elements remaining after this pass
+        uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
+
+        vk_subbuffer src_buf;
+        vk_subbuffer dst_buf;
+
+        if (num_elements == ncols) {
+            pc2.first_pass = 1;
+            src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
+        } else {
+            src_buf = { ctx->prealloc_x, dbl_buf_index * dbl_buf_size, dbl_buf_size };
+        }
+        if (num_dst_elements == k) {
+            pc2.last_pass = 1;
+            dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
+        } else {
+            dst_buf = { ctx->prealloc_x, (dbl_buf_index ^ 1) * dbl_buf_size, dbl_buf_size };
+        }
+
+        elements[0] = num_elements;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc2, elements);
+        num_elements = num_dst_elements;
+        dbl_buf_index ^= 1;
+        if (num_elements > k) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+    }
+    ctx->prealloc_x_need_sync = true;
+}
+
 static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0));
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_SUM, p);
@@ -11755,6 +11883,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
         }
 
+        break;
+    case GGML_OP_TOP_K:
+        ggml_vk_topk(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_SUM:
         ggml_vk_sum(ctx, compute_ctx, src0, node);
@@ -13787,6 +13919,22 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     return op->ne[0] <= (1 << device->max_workgroup_size_log2);
                 }
             }
+        case GGML_OP_TOP_K:
+            {
+                if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
+                    return false;
+                }
+                ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+                auto device = ggml_vk_get_device(ctx->device);
+                // We could potentially support larger, using argsort to sort the
+                // whole thing. Not clear if this is needed.
+                uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
+                if (min_pipeline >= num_topk_pipelines ||
+                    !device->pipeline_topk_f32[min_pipeline]) {
+                    return false;
+                }
+            }
+            return true;
         case GGML_OP_UPSCALE:
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
@@ -14459,6 +14607,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]);
         } else if (tensor->op == GGML_OP_ARGSORT) {
             tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params);
+        } else if (tensor->op == GGML_OP_TOP_K) {
+            tensor_clone = ggml_top_k(ggml_ctx, src_clone[0], tensor->ne[0]);
         } else if (tensor->op == GGML_OP_SUM) {
             tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
         } else if (tensor->op == GGML_OP_SUM_ROWS) {

+ 113 - 0
ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp

@@ -0,0 +1,113 @@
+#version 450
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.glsl"
+
+layout(constant_id = 0) const int BLOCK_SIZE = 1024;
+layout(constant_id = 1) const int NCOLS_PADDED_LOG2 = 10;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// Input can either be the source (A) or intermediate values (S).
+// Similarly, output can be either destination (D) or intermediate values (S).
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 0) readonly buffer S {ivec2 data_s[];};
+layout (binding = 1) writeonly buffer D {int data_d[];};
+layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
+
+layout (push_constant) uniform parameter {
+    uint orig_ncols;
+    uint ncols_input;
+    uint ncols_output;
+    uint nrows;
+    uint first_pass;
+    uint last_pass;
+} p;
+
+// pairs of (gid, value)
+shared ivec2 dst_row[BLOCK_SIZE];
+
+void topk(bool needs_bounds_check, const uint row) {
+    const int col = int(gl_LocalInvocationID.x);
+
+    // initialize indices
+    if (gl_GlobalInvocationID.x < p.ncols_input) {
+        if (p.first_pass != 0) {
+            const uint row_offset = row * p.ncols_input;
+            dst_row[col] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
+        } else {
+            const uint row_offset = row * p.orig_ncols;
+            dst_row[col] = data_s[row_offset + gl_GlobalInvocationID.x];
+        }
+    } else {
+        dst_row[col] = ivec2(p.orig_ncols, 0);
+    }
+    barrier();
+
+    if (p.ncols_output == 1) {
+        // Fast path for single output - just do a max reduction
+        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
+            if (col < s) {
+                ivec2 a = dst_row[col];
+                ivec2 b = dst_row[col + s];
+                if (a.x >= p.orig_ncols ||
+                    b.x < p.orig_ncols && b.y > a.y) {
+                    dst_row[col] = b;
+                }
+            }
+            barrier();
+        }
+    } else {
+        // bitonic sort on this group of elements
+        uint num_outer_loop_iters = NCOLS_PADDED_LOG2;
+        for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) {
+            uint num_inner_loop_iters = outer_idx + 1;
+            for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) {
+                const int ixj = int(col ^ j);
+
+                int idx_0 = (col & k) == 0 ? col : ixj;
+                int idx_1 = (col & k) == 0 ? ixj : col;
+
+                ivec2 sh_idx_0 = dst_row[idx_0];
+                ivec2 sh_idx_1 = dst_row[idx_1];
+                bool idx_0_oob = needs_bounds_check ? sh_idx_0.x >= p.orig_ncols : false;
+                bool idx_1_oob = needs_bounds_check ? sh_idx_1.x >= p.orig_ncols : false;
+
+                if ((idx_0_oob ||
+                    (!idx_1_oob && intBitsToFloat(sh_idx_0.y) < intBitsToFloat(sh_idx_1.y))) && (ixj > col)) {
+                    dst_row[idx_0] = sh_idx_1;
+                    dst_row[idx_1] = sh_idx_0;
+                }
+
+                barrier();
+            }
+        }
+    }
+
+    if (col < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+        if (p.last_pass != 0) {
+            const uint row_offset = row * p.ncols_output;
+            data_d[row_offset + col] = dst_row[col].x;
+        } else {
+            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
+            data_t[row_offset + col] = dst_row[col];
+        }
+    }
+}
+
+void main() {
+    // Fast path for fully occupied workgroups
+    if ((p.ncols_input % BLOCK_SIZE) == 0) {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            topk(false, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    } else {
+        uint row = gl_WorkGroupID.y;
+        while (row < p.nrows) {
+            topk(true, row);
+            row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+        }
+    }
+}

+ 199 - 0
ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp

@@ -0,0 +1,199 @@
+#version 450
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_debug_printf : enable
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_ballot : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.glsl"
+
+layout(constant_id = 0) const int BLOCK_SIZE = 1024;
+layout(constant_id = 1) const int SUBGROUP_SIZE = 32;
+layout(constant_id = 2) const int SUBGROUP_SIZE_LOG2 = 5;
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+// Input can either be the source (A) or intermediate values (S).
+// Similarly, output can be either destination (D) or intermediate values (S).
+layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
+layout (binding = 0) readonly buffer S {ivec2 data_s[];};
+layout (binding = 1) writeonly buffer D {int data_d[];};
+layout (binding = 1) writeonly buffer T {ivec2 data_t[];};
+
+layout (push_constant) uniform parameter {
+    uint orig_ncols;
+    uint ncols_input;
+    uint ncols_output;
+    uint nrows;
+    uint first_pass;
+    uint last_pass;
+} p;
+
+// pairs of (gid, value)
+shared ivec2 dst_row[BLOCK_SIZE];
+
+shared int counts[SUBGROUP_SIZE];
+shared int sh_min_idx;
+shared uint sh_total;
+shared uint offset_partials[BLOCK_SIZE / SUBGROUP_SIZE];
+
+// Map float values to uint such that comparisons still work.
+// Positive values set the high bit, negative values are inverted.
+// +0.0 -> 0x80000000, -0.0 -> 0x7FFFFFFF are in the correct places.
+uint f2ui(float x) {
+    uint y = floatBitsToUint(x);
+    if ((y & 0x80000000) != 0) {
+        y ^= ~0;
+    } else {
+        y |= 0x80000000;
+    }
+    return y;
+}
+
+void topk(const uint row) {
+    const int tid = int(gl_LocalInvocationID.x);
+
+    // initialize indices
+    if (gl_GlobalInvocationID.x < p.ncols_input) {
+        if (p.first_pass != 0) {
+            const uint row_offset = row * p.ncols_input;
+            dst_row[tid] = ivec2(gl_GlobalInvocationID.x, floatBitsToInt(data_a[row_offset + gl_GlobalInvocationID.x]));
+        } else {
+            const uint row_offset = row * p.orig_ncols;
+            dst_row[tid] = data_s[row_offset + gl_GlobalInvocationID.x];
+        }
+    } else {
+        dst_row[tid] = ivec2(p.orig_ncols, 0xFF800000); // -inf
+    }
+    barrier();
+
+    if (p.ncols_output == 1) {
+        // Fast path for single output - just do a max reduction
+        [[unroll]] for (int s = BLOCK_SIZE / 2; s >= 1; s /= 2) {
+            if (tid < s) {
+                ivec2 a = dst_row[tid];
+                ivec2 b = dst_row[tid + s];
+                if (a.x >= p.orig_ncols ||
+                    b.x < p.orig_ncols && b.y > a.y) {
+                    dst_row[tid] = b;
+                }
+            }
+            barrier();
+        }
+    } else {
+        // Do an N-ary search to find the K-th largest value.
+        // We remap the float values to be comparable as unsigned integers,
+        // and split the range into 2^N smaller ranges where N is the
+        // subgroup size. Count how many values are in each range, if the K-th
+        // largest value is in the middle of one of thee ranges then repeat
+        // and split again.
+
+        // Mask is the current set of bits we're searching. Shift is the LSB index.
+        int shift = 32 - SUBGROUP_SIZE_LOG2;
+        uint mask = ((1 << SUBGROUP_SIZE_LOG2) - 1) << shift;
+
+        // The current range.
+        uint range_min = 0;
+        uint range_max = 0xFF800000;
+        // How many are above the current range, and how many we need to find.
+        uint total = 0;
+        uint limit = min(p.ncols_output, p.ncols_input - gl_WorkGroupID.x * BLOCK_SIZE);
+
+        while (mask != 0) {
+            barrier();
+            // Initialize bucket counts to zero.
+            if (tid < SUBGROUP_SIZE) {
+                counts[tid] = 0;
+            }
+            barrier();
+            // Count how many values are in each bucket.
+            if (tid < p.ncols_input) {
+                float y = intBitsToFloat(dst_row[tid].y);
+                uint fy = f2ui(y);
+                if (fy >= range_min && fy < range_max) {
+                    uint bucket = (fy & mask) >> shift;
+                    atomicAdd(counts[bucket], 1);
+                }
+            }
+            barrier();
+
+            // On the first subgroup, do a scan to count (from the top down) how
+            // many elements are in the top N buckets. Find the index of the first
+            // that is over the limit. Copy it to the other invocations through
+            // shared memory.
+            if (tid < SUBGROUP_SIZE) {
+                uint partial_sum = counts[SUBGROUP_SIZE - 1 - tid];
+                partial_sum = subgroupInclusiveAdd(partial_sum) + total;
+                uint t = subgroupBallotFindLSB(subgroupBallot(partial_sum >= limit));
+                if (tid == t) {
+                    sh_min_idx = int(SUBGROUP_SIZE - 1 - t);
+                    sh_total = partial_sum;
+                }
+            }
+            barrier();
+            int min_idx = sh_min_idx;
+            total = sh_total;
+
+            // Update the range, and break if we've found the K-th largest.
+            range_max = range_min + ((min_idx + 1) << shift);
+            range_min = range_min + (min_idx << shift);
+
+            if (total == p.ncols_output) {
+                break;
+            }
+            total -= counts[min_idx];
+            mask >>= SUBGROUP_SIZE_LOG2;
+            shift -= SUBGROUP_SIZE_LOG2;
+            if (shift < 0) {
+                shift = 0;
+            }
+        }
+
+        ivec2 v = dst_row[tid];
+
+        // We need to compact these values to the start of the dst_row array.
+        // Have each subgroup count how many items it'll store, so other
+        // subgroups can compute their base offset.
+        bool top = f2ui(intBitsToFloat(v.y)) >= range_min;
+        uvec4 b = subgroupBallot(top);
+        uint bit_count = subgroupBallotBitCount(b);
+        if ((tid % SUBGROUP_SIZE) == 0) {
+            offset_partials[tid / SUBGROUP_SIZE] = bit_count;
+        }
+        barrier();
+
+        uint out_idx = 0;
+        [[unroll]] for (int i = 0; i < BLOCK_SIZE / SUBGROUP_SIZE; ++i) {
+            if (i < tid / SUBGROUP_SIZE) {
+                out_idx += offset_partials[i];
+            }
+        }
+
+        uint bit_count_ex = subgroupBallotExclusiveBitCount(b);
+        if (top) {
+            // TODO: Copy directly to the output?
+            dst_row[out_idx + bit_count_ex] = v;
+        }
+
+        barrier();
+    }
+
+    if (tid < p.ncols_output && gl_GlobalInvocationID.x < p.orig_ncols) {
+        if (p.last_pass != 0) {
+            const uint row_offset = row * p.ncols_output;
+            data_d[row_offset + tid] = dst_row[tid].x;
+        } else {
+            const uint row_offset = row * p.orig_ncols + gl_WorkGroupID.x * p.ncols_output;
+            data_t[row_offset + tid] = dst_row[tid];
+        }
+    }
+}
+
+void main() {
+    uint row = gl_WorkGroupID.y;
+    while (row < p.nrows) {
+        topk(row);
+        row += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
+    }
+}

+ 3 - 0
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

@@ -913,6 +913,9 @@ void process_shaders() {
     string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
     string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}});
 
+    string_to_spv("topk_argsort_f32", "topk_argsort.comp", {{"A_TYPE", "float"}});
+    string_to_spv("topk_nary_search_f32", "topk_nary_search.comp", {{"A_TYPE", "float"}});
+
     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));

+ 15 - 1
tests/test-backend-ops.cpp

@@ -7635,6 +7635,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection)
     }
 
+    for (int i = 0; i < 20; ++i) {
+        for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) {
+            if (k <= 1<<i) {
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i), 1, 1, 1}, k));
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {(1<<i) + 11, 1, 2, 1}, k));
+            }
+        }
+    }
     for (int k : {1, 2, 3, 7, 15}) {
         test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {16, 10, 10, 10}, k));
         test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {60, 10, 10, 10}, k));
@@ -8032,7 +8040,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     }
 
     test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
-    test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {65000, 16, 1, 1}, 40));
+    for (auto k : {1, 10, 40}) {
+        for (auto nrows : {1, 16}) {
+            for (auto cols : {k, 1000, 65000, 200000}) {
+                test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));
+            }
+        }
+    }
 
     return test_cases;
 }