@@ -705,6 +705,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32[num_argsort_pipelines];
     vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
     vk_pipeline pipeline_sum_rows_f32;
+    vk_pipeline pipeline_cumsum_f32;
     vk_pipeline pipeline_argmax_f32;
     vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
@@ -3968,6 +3969,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
+
     ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
 
 #define IM2COL(bda) \
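
The GLSL source of cumsum_f32 is not part of this diff, but the pipeline creation above pins a workgroup size of 128 via the specialization constants, passes the subgroup size as a second constant, and the trailing arguments appear to request full subgroups with a required subgroup size (matching the subgroup_arithmetic gate in supports_op further down). That strongly suggests a subgroup-scan kernel. As a rough CPU model of that technique only (a hypothetical helper, not the actual shader): each subgroup computes an inclusive prefix sum of its chunk, and a running carry of the previous chunks' totals is added on top.

#include <algorithm>
#include <cstdio>
#include <vector>

// CPU model of a subgroup-decomposed inclusive scan: each "subgroup"
// scans its own chunk, then the running total of all previous chunks
// (the carry) is added on top. On the GPU the inner loop would be a
// single subgroupInclusiveAdd and the carry a broadcast of the chunk sum.
static void cumsum_row(const float * src, float * dst, int n, int subgroup_size) {
    float carry = 0.0f;
    for (int base = 0; base < n; base += subgroup_size) {
        const int end = std::min(base + subgroup_size, n);
        float acc = 0.0f;
        for (int i = base; i < end; ++i) {
            acc   += src[i];
            dst[i] = carry + acc;
        }
        carry += acc; // chunk total feeds the next subgroup's carry
    }
}

int main() {
    std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};
    std::vector<float> y(x.size());
    cumsum_row(x.data(), y.data(), (int)x.size(), 4);
    for (float v : y) printf("%g ", v); // prints: 1 3 6 10 15 21 28 36
    printf("\n");
    return 0;
}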
@@ -8457,6 +8460,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sum_rows_f32;
         }
         return nullptr;
+    case GGML_OP_CUMSUM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cumsum_f32;
+        }
+        return nullptr;
     case GGML_OP_ARGMAX:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             return ctx->device->pipeline_argmax_f32;
@@ -8821,6 +8829,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
+    case GGML_OP_CUMSUM:
     case GGML_OP_MEAN:
     case GGML_OP_ARGMAX:
         {
@@ -10150,6 +10159,11 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_MEAN, p);
 }
 
+static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
+}
+
 static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f });
 }
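
The new host function reuses the sum_rows push-constant layout with the row length src0->ne[0], i.e. the scan runs along dim 0 independently for every row. A minimal CPU reference of those semantics (a sketch, assuming contiguous rows, which supports_op enforces below):

#include <cstdint>

// Reference: dst[i] within each row is the inclusive prefix sum of that
// row, for nrows rows of length ne0. Assumes contiguous rows, matching
// the ggml_is_contiguous_rows() requirement in supports_op.
static void cumsum_reference(const float * src, float * dst, int64_t ne0, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        float acc = 0.0f;
        for (int64_t i = 0; i < ne0; ++i) {
            acc += src[r*ne0 + i];
            dst[r*ne0 + i] = acc;
        }
    }
}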
@@ -11749,6 +11763,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_SUM_ROWS:
         ggml_vk_sum_rows(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_CUMSUM:
+        ggml_vk_cumsum(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_MEAN:
         ggml_vk_mean(ctx, compute_ctx, src0, node);
@@ -13786,6 +13804,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SUM_ROWS:
     case GGML_OP_MEAN:
         return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+    case GGML_OP_CUMSUM:
+        {
+            ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+            auto device = ggml_vk_get_device(ctx->device);
+            if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
+                return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
+            }
+            return false;
+        }
     case GGML_OP_ARGMAX:
     case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
@@ -14436,6 +14463,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_sum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SUM_ROWS) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]);
+    } else if (tensor->op == GGML_OP_CUMSUM) {
+        tensor_clone = ggml_cumsum(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_MEAN) {
         tensor_clone = ggml_mean(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_ARGMAX) {
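
The check_results hunk above compares the Vulkan output against the CPU ggml_cumsum. For a standalone smoke test outside that harness, a minimal sketch using the public ggml API is shown below (assumptions: ggml_cumsum is declared in ggml.h as the diff implies, and the graph-compute helper lives in ggml-cpu.h as in current ggml trees):

#include "ggml.h"
#include "ggml-cpu.h"
#include <cstdio>
#include <cstring>

int main() {
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    // a 4-element F32 row; cumsum scans along dim 0
    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    const float vals[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    memcpy(a->data, vals, sizeof(vals));

    ggml_tensor * c = ggml_cumsum(ctx, a); // inclusive prefix sum
    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads*/ 1);

    const float * out = (const float *) c->data;
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]); // expect: 1 3 6 10
    ggml_free(ctx);
    return 0;
}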