|
|
@@ -1227,6 +1227,7 @@ struct vk_op_topk_push_constants {
|
|
|
uint32_t orig_ncols;
|
|
|
uint32_t ncols_input;
|
|
|
uint32_t ncols_output;
|
|
|
+ uint32_t k;
|
|
|
uint32_t nrows;
|
|
|
uint32_t first_pass;
|
|
|
uint32_t last_pass;
|
|
|
@@ -1673,6 +1674,14 @@ class vk_perf_logger {
|
|
|
timings[name.str()].push_back(time);
|
|
|
return;
|
|
|
}
|
|
|
+ if (node->op == GGML_OP_TOP_K) {
|
|
|
+ std::stringstream name;
|
|
|
+ name << ggml_op_name(node->op) <<
|
|
|
+ " K=" << node->ne[0] <<
|
|
|
+ " (" << node->src[0]->ne[0] << "," << node->src[0]->ne[1] << "," << node->src[0]->ne[2] << "," << node->src[0]->ne[3] << ")";
|
|
|
+ timings[name.str()].push_back(time);
|
|
|
+ return;
|
|
|
+ }
|
|
|
timings[ggml_op_name(node->op)].push_back(time);
|
|
|
}
|
|
|
private:
|
|
|
@@ -10345,17 +10354,8 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
|
uint32_t nrows = ggml_nrows(src0);
|
|
|
uint32_t k = dst->ne[0];
|
|
|
|
|
|
- vk_op_topk_push_constants pc { ncols, ncols, k, nrows, 0, 0 };
|
|
|
+ vk_op_topk_push_constants pc { ncols, ncols, ncols, k, nrows, 0, 0 };
|
|
|
|
|
|
- // Reserve space for ivec2 per element, double buffered
|
|
|
- const size_t dbl_buf_size = size_t{ncols} * nrows * 2 * sizeof(int);
|
|
|
- const size_t x_sz = dbl_buf_size * 2;
|
|
|
- uint32_t dbl_buf_index = 0;
|
|
|
-
|
|
|
- if (ctx->prealloc_size_x < x_sz) {
|
|
|
- ctx->prealloc_size_x = x_sz;
|
|
|
- ggml_vk_preallocate_buffers(ctx, subctx);
|
|
|
- }
|
|
|
if (ctx->prealloc_x_need_sync) {
|
|
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
|
}
|
|
|
@@ -10370,8 +10370,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
|
// largest elements. Repeat until we have the top K elements.
|
|
|
// Need to do at least one iteration to write out the results.
|
|
|
bool done_one_iter = false;
|
|
|
+ uint32_t dbl_buf_index = 0;
|
|
|
+ size_t dbl_buf_size = 0; // sized on the first pass below; zero-init so it never holds an indeterminate value
|
|
|
while (num_elements > k || !done_one_iter) {
|
|
|
- done_one_iter = true;
|
|
|
|
|
|
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
|
|
|
// But if K is larger, then we need a larger workgroup
|
|
|
@@ -10411,6 +10412,21 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
|
// Number of elements remaining after this pass
|
|
|
uint32_t num_dst_elements = (num_elements / pipeline->wg_denoms[0]) * k + std::min(k, num_elements % pipeline->wg_denoms[0]);
|
|
|
|
|
|
+ pc2.ncols_output = num_dst_elements;
|
|
|
+
|
|
|
+ if (!done_one_iter) {
|
|
|
+ // Reserve space for ivec2 per element, double buffered
|
|
|
+ // K per workgroup per row
|
|
|
+ dbl_buf_size = num_dst_elements * nrows * 2 * sizeof(int);
|
|
|
+ dbl_buf_size = ROUNDUP_POW2(dbl_buf_size, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
|
|
|
+ const size_t x_sz = dbl_buf_size * 2;
|
|
|
+
|
|
|
+ if (ctx->prealloc_size_x < x_sz) {
|
|
|
+ ctx->prealloc_size_x = x_sz;
|
|
|
+ ggml_vk_preallocate_buffers(ctx, subctx);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
vk_subbuffer src_buf;
|
|
|
vk_subbuffer dst_buf;
|
|
|
|
|
|
@@ -10436,6 +10452,7 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
|
|
|
if (num_elements > k) {
|
|
|
ggml_vk_sync_buffers(ctx, subctx);
|
|
|
}
|
|
|
+ done_one_iter = true;
|
|
|
}
|
|
|
ctx->prealloc_x_need_sync = true;
|
|
|
}
|