Просмотр исходного кода

metal : add GELU implementation (#1770)

Co-authored-by: Adam Treat <adam@nomic.ai>
AT 2 лет назад
Родитель
Сommit
92f44ff7f7
2 измененных файлов с 27 добавлено и 0 удалено
  1. 16 0
      ggml-metal.m
  2. 11 0
      ggml-metal.metal

+ 16 - 0
ggml-metal.m

@@ -45,6 +45,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(scale);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
+    GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
@@ -135,6 +136,7 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(scale);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
+        GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
@@ -420,6 +422,20 @@ void ggml_metal_graph_compute(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+            case GGML_OP_GELU:
+            {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+            } break;
             case GGML_OP_SOFT_MAX:
                 {
                     if (encoder == nil) {

+ 11 - 0
ggml-metal.metal

@@ -81,6 +81,17 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+constant float GELU_COEF_A    = 0.044715f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+    device const float * src0,
+    device       float * dst,
+    uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
 kernel void kernel_soft_max(
         device const float * src0,
         device       float * dst,