|
@@ -149,6 +149,8 @@ enum ggml_metal_kernel_type {
|
|
|
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
|
GGML_METAL_KERNEL_TYPE_SIGMOID,
|
|
|
GGML_METAL_KERNEL_TYPE_GELU,
|
|
GGML_METAL_KERNEL_TYPE_GELU,
|
|
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
|
GGML_METAL_KERNEL_TYPE_GELU_4,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_GELU_ERF,
|
|
|
|
|
+ GGML_METAL_KERNEL_TYPE_GELU_ERF_4,
|
|
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK,
|
|
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
|
GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,
|
|
|
GGML_METAL_KERNEL_TYPE_SILU,
|
|
GGML_METAL_KERNEL_TYPE_SILU,
|
|
@@ -1103,6 +1105,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID, sigmoid, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_4, gelu_4, true);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF, gelu_erf, true);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_ERF_4, gelu_erf_4, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4, gelu_quick_4, true);
|
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
|
|
@@ -1613,6 +1617,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
|
|
case GGML_UNARY_OP_RELU:
|
|
case GGML_UNARY_OP_RELU:
|
|
|
case GGML_UNARY_OP_SIGMOID:
|
|
case GGML_UNARY_OP_SIGMOID:
|
|
|
case GGML_UNARY_OP_GELU:
|
|
case GGML_UNARY_OP_GELU:
|
|
|
|
|
+ case GGML_UNARY_OP_GELU_ERF:
|
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
|
case GGML_UNARY_OP_SILU:
|
|
case GGML_UNARY_OP_SILU:
|
|
|
case GGML_UNARY_OP_ELU:
|
|
case GGML_UNARY_OP_ELU:
|
|
@@ -2251,6 +2256,25 @@ static bool ggml_metal_encode_node(
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_UNARY_OP_GELU_ERF:
|
|
|
|
|
+ {
|
|
|
|
|
+ int64_t n = ggml_nelements(dst);
|
|
|
|
|
+
|
|
|
|
|
+ id<MTLComputePipelineState> pipeline = nil;
|
|
|
|
|
+
|
|
|
|
|
+ if (n % 4 == 0) {
|
|
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF_4].pipeline;
|
|
|
|
|
+ n /= 4;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_ERF].pipeline;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ [encoder setComputePipelineState:pipeline];
|
|
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
|
+
|
|
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
|
+ } break;
|
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
|
{
|
|
{
|
|
|
int64_t n = ggml_nelements(dst);
|
|
int64_t n = ggml_nelements(dst);
|