|
|
@@ -1082,6 +1082,7 @@ struct vk_op_soft_max_push_constants {
|
|
|
|
|
|
struct vk_op_argsort_push_constants {
|
|
|
uint32_t ncols;
|
|
|
+ uint32_t nrows;
|
|
|
int32_t order;
|
|
|
};
|
|
|
|
|
|
@@ -8708,6 +8709,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
break;
|
|
|
case GGML_OP_ARGSORT:
|
|
|
elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
|
|
|
+ elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
|
|
|
break;
|
|
|
case GGML_OP_IM2COL:
|
|
|
{
|
|
|
@@ -9954,9 +9956,11 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
int32_t * op_params = (int32_t *)dst->op_params;
|
|
|
|
|
|
uint32_t ncols = src0->ne[0];
|
|
|
+ uint32_t nrows = ggml_nrows(src0);
|
|
|
|
|
|
ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGSORT, {
|
|
|
ncols,
|
|
|
+ nrows,
|
|
|
op_params[0],
|
|
|
}, dryrun);
|
|
|
}
|