|
@@ -45,6 +45,7 @@ struct ggml_metal_context {
|
|
|
GGML_METAL_DECL_KERNEL(scale);
|
|
GGML_METAL_DECL_KERNEL(scale);
|
|
|
GGML_METAL_DECL_KERNEL(silu);
|
|
GGML_METAL_DECL_KERNEL(silu);
|
|
|
GGML_METAL_DECL_KERNEL(relu);
|
|
GGML_METAL_DECL_KERNEL(relu);
|
|
|
|
|
+ GGML_METAL_DECL_KERNEL(gelu);
|
|
|
GGML_METAL_DECL_KERNEL(soft_max);
|
|
GGML_METAL_DECL_KERNEL(soft_max);
|
|
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
|
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
|
@@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
|
|
GGML_METAL_ADD_KERNEL(scale);
|
|
GGML_METAL_ADD_KERNEL(scale);
|
|
|
GGML_METAL_ADD_KERNEL(silu);
|
|
GGML_METAL_ADD_KERNEL(silu);
|
|
|
GGML_METAL_ADD_KERNEL(relu);
|
|
GGML_METAL_ADD_KERNEL(relu);
|
|
|
|
|
+ GGML_METAL_ADD_KERNEL(gelu);
|
|
|
GGML_METAL_ADD_KERNEL(soft_max);
|
|
GGML_METAL_ADD_KERNEL(soft_max);
|
|
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
|
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
|
@@ -420,6 +422,20 @@ void ggml_metal_graph_compute(
|
|
|
|
|
|
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
} break;
|
|
} break;
|
|
|
|
|
+ case GGML_OP_GELU:
|
|
|
|
|
+ {
|
|
|
|
|
+ if (encoder == nil) {
|
|
|
|
|
+ encoder = [command_buffer computeCommandEncoder];
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ [encoder setComputePipelineState:ctx->pipeline_gelu];
|
|
|
|
|
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
|
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t n = ggml_nelements(dst);
|
|
|
|
|
+
|
|
|
|
|
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
|
|
|
|
+ } break;
|
|
|
case GGML_OP_SOFT_MAX:
|
|
case GGML_OP_SOFT_MAX:
|
|
|
{
|
|
{
|
|
|
if (encoder == nil) {
|
|
if (encoder == nil) {
|