@@ -498,6 +498,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_MEAN,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
     GGML_METAL_KERNEL_TYPE_ARGMAX,
@@ -1454,6 +1455,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32, pool_2d_avg_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
@@ -1653,6 +1655,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_LOG:
             return false; // TODO: implement
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
@@ -2400,11 +2403,30 @@ static bool ggml_metal_encode_node(
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_MEAN:
             {
                 GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->op) {
+                    case GGML_OP_SUM_ROWS:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+                        break;
+                    case GGML_OP_MEAN:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MEAN].pipeline;
+                        break;
+                    default:
+                        GGML_ABORT("fatal error");
+                }
+
+                int nth = 32; // SIMD width
+
+                while (nth < ne00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+                nth = MIN(nth, ne00);

                 ggml_metal_kargs_sum_rows args = {
                     /*.ne00 =*/ ne00,
@@ -2434,11 +2456,12 @@ static bool ggml_metal_encode_node(
                 };

                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
-                [encoder setBytes:&args length:sizeof(args) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];

-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
         case GGML_OP_SOFT_MAX:
             {
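
The hunks above cover only the host side: they register the new pipeline, select it per op, size the threadgroup (nth), and bind the kernel arguments (args at buffer index 0, src0 at 1, dst at 2, plus 32*sizeof(float) of threadgroup scratch). For orientation, below is a minimal device-side sketch of such a row reduction in Metal Shading Language. It is an illustration under assumed names (kargs_rows_sketch, kernel_rows_reduce_sketch) and an assumed args layout, not the kernel_sum_rows/kernel_mean code in ggml-metal.metal: one threadgroup reduces one row, partials are combined with simd_sum, one partial per simdgroup is staged in the 32-float threadgroup buffer, and the mean variant divides the row sum by ne00.

// Hypothetical sketch only; the real kernels in ggml-metal.metal may differ.
#include <metal_stdlib>
using namespace metal;

// Illustrative argument block; the actual ggml_metal_kargs_sum_rows has its own layout.
typedef struct {
    int64_t  ne00; // elements per row
    uint64_t nb01; // src byte stride between rows
    uint64_t nb02; // src byte stride between planes
    uint64_t nb03; // src byte stride between batches
    uint64_t nb1;  // dst byte strides
    uint64_t nb2;
    uint64_t nb3;
} kargs_rows_sketch;

template<bool is_mean>
kernel void kernel_rows_reduce_sketch(
        constant kargs_rows_sketch & args  [[buffer(0)]],      // kargs bound at index 0
        device   const char        * src0  [[buffer(1)]],      // src tensor at index 1
        device         char        * dst   [[buffer(2)]],      // dst tensor at index 2
        threadgroup    float       * shmem [[threadgroup(0)]], // 32 floats: one partial per simdgroup
        uint3   tgpig [[threadgroup_position_in_grid]],
        ushort3 tpitg [[thread_position_in_threadgroup]],
        ushort3 ntg   [[threads_per_threadgroup]],
        ushort  sgitg [[simdgroup_index_in_threadgroup]],
        ushort  tiisg [[thread_index_in_simdgroup]]) {
    // one threadgroup per row: (tgpig.x, tgpig.y, tgpig.z) = (row, plane, batch),
    // matching the host-side dispatch of MTLSizeMake(ne01, ne02, ne03) threadgroups
    device const float * x = (device const float *)(src0 + tgpig.x*args.nb01 + tgpig.y*args.nb02 + tgpig.z*args.nb03);

    // each thread accumulates a strided slice of the row
    float sumf = 0.0f;
    for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
        sumf += x[i0];
    }

    // reduce within each simdgroup, then combine the per-simdgroup partials
    sumf = simd_sum(sumf);
    if (tiisg == 0) {
        shmem[sgitg] = sumf;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tpitg.x == 0) {
        float res = 0.0f;
        for (ushort i = 0; i < (ntg.x + 31)/32; ++i) {
            res += shmem[i];
        }
        if (is_mean) {
            res /= (float) args.ne00; // mean = row sum divided by the row length
        }
        device float * y = (device float *)(dst + tgpig.x*args.nb1 + tgpig.y*args.nb2 + tgpig.z*args.nb3);
        y[0] = res;
    }
}

typedef decltype(kernel_rows_reduce_sketch<false>) kernel_rows_reduce_sketch_t;

template [[host_name("kernel_sum_rows_sketch")]] kernel kernel_rows_reduce_sketch_t kernel_rows_reduce_sketch<false>;
template [[host_name("kernel_mean_sketch")]]     kernel kernel_rows_reduce_sketch_t kernel_rows_reduce_sketch<true>;

On the graph side, GGML_OP_MEAN nodes are produced by ggml_mean(), so with ggml_metal_supports_op now returning true for the op, such nodes run on the Metal backend instead of falling back to the CPU.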