Quellcode durchsuchen

metal: add neg operator (#13029)

Jeffrey Morgan vor 9 Monaten
Ursprung
Commit
4ba9d711ba
2 geänderte Dateien mit 22 neuen und 0 gelöschten Zeilen
  1. 15 0
      ggml/src/ggml-metal/ggml-metal.m
  2. 7 0
      ggml/src/ggml-metal/ggml-metal.metal

+ 15 - 0
ggml/src/ggml-metal/ggml-metal.m

@@ -481,6 +481,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SQRT,
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
+    GGML_METAL_KERNEL_TYPE_NEG,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
@@ -1159,6 +1160,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                            sqrt,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                             sin,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                             cos,                             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG,                             neg,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
@@ -1320,6 +1322,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_ELU:
+                case GGML_UNARY_OP_NEG:
                     return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
                 default:
                     return false;
@@ -2010,6 +2013,18 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
+                case GGML_UNARY_OP_NEG:
+                {
+                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NEG].pipeline;
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
                 default:
                 {
                     GGML_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));

+ 7 - 0
ggml/src/ggml-metal/ggml-metal.metal

@@ -949,6 +949,13 @@ kernel void kernel_cos(
     dst[tpig] = cos(src0[tpig]);
 }
 
+kernel void kernel_neg(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = -src0[tpig];
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,