|
|
@@ -173,6 +173,12 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_SILU,
|
|
|
GGML_METAL_KERNEL_TYPE_SILU_4,
|
|
|
GGML_METAL_KERNEL_TYPE_ELU,
|
|
|
+ GGML_METAL_KERNEL_TYPE_ABS,
|
|
|
+ GGML_METAL_KERNEL_TYPE_SGN,
|
|
|
+ GGML_METAL_KERNEL_TYPE_STEP,
|
|
|
+ GGML_METAL_KERNEL_TYPE_HARDSWISH,
|
|
|
+ GGML_METAL_KERNEL_TYPE_HARDSIGMOID,
|
|
|
+ GGML_METAL_KERNEL_TYPE_EXP,
|
|
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,
|
|
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,
|
|
|
GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,
|
|
|
@@ -1155,6 +1161,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4, silu_4, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ELU, elu, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ABS, abs, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SGN, sgn, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_STEP, step, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSWISH, hardswish, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_HARDSIGMOID, hardsigmoid, true);
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_EXP, exp, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16, soft_max_f16, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4, soft_max_f16_4, has_simdgroup_reduction);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32, soft_max_f32, has_simdgroup_reduction);
|
|
|
@@ -1688,6 +1700,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
case GGML_UNARY_OP_SILU:
|
|
|
case GGML_UNARY_OP_ELU:
|
|
|
case GGML_UNARY_OP_NEG:
|
|
|
+ case GGML_UNARY_OP_ABS:
|
|
|
+ case GGML_UNARY_OP_SGN:
|
|
|
+ case GGML_UNARY_OP_STEP:
|
|
|
+ case GGML_UNARY_OP_HARDSWISH:
|
|
|
+ case GGML_UNARY_OP_HARDSIGMOID:
|
|
|
+ case GGML_UNARY_OP_EXP:
|
|
|
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
|
|
default:
|
|
|
return false;
|
|
|
@@ -2439,6 +2457,78 @@ static bool ggml_metal_encode_node(
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
} break;
|
|
|
+ case GGML_UNARY_OP_ABS: // encode one compute dispatch for the ABS unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ABS].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
+ case GGML_UNARY_OP_SGN: // encode one compute dispatch for the SGN unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SGN].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
+ case GGML_UNARY_OP_STEP: // encode one compute dispatch for the STEP unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_STEP].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
+ case GGML_UNARY_OP_HARDSWISH: // encode one compute dispatch for the HARDSWISH unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSWISH].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
+ case GGML_UNARY_OP_HARDSIGMOID: // encode one compute dispatch for the HARDSIGMOID unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_HARDSIGMOID].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
+ case GGML_UNARY_OP_EXP: // encode one compute dispatch for the EXP unary op
|
|
|
+ {
|
|
|
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_EXP].pipeline; // pipeline state looked up in ctx->kernels by kernel type
|
|
|
+
|
|
|
+ [encoder setComputePipelineState:pipeline]; // make it the active pipeline on the compute encoder
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // id_src0 at offset offs_src0 -> buffer argument index 0
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // id_dst at offset offs_dst -> buffer argument index 1
|
|
|
+
|
|
|
+ const int64_t n = ggml_nelements(dst); // presumably the total element count of dst - confirm against ggml.h
|
|
|
+
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; // n x 1 x 1 threadgroups, each holding a single thread
|
|
|
+ } break;
|
|
|
default:
|
|
|
{
|
|
|
GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|