Browse Source

ggml : extend ggml_pool_1d + metal (#16429)

* chore: resolve conflicts

* feat: ggml metal impl

* fix: ggml_metal_kargs_pool_1d struct

* fix: require contiguous input

* chore: test pool_1d

* chore: limit pool1d test cases to p0=0 and s0=k0 to conform with asserts

* chore: add p0 and s0 to testing

* fix: allow padding for cpu and metal

* Update ggml/src/ggml-metal/ggml-metal.metal

* fix: correct single-threaded loop

* ggml : cleanup

* tests : add ne[1] != 1 tests

* fix: ne[1] handling in np

* cont : fixes

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Thore Koritzius 1 week ago
parent
commit
388ce82241

+ 63 - 37
ggml/src/ggml-cpu/ops.cpp

@@ -7,10 +7,9 @@
 #include "unary-ops.h"
 #include "vec.h"
 
-#include <cfloat>
 #include <algorithm>
+#include <cfloat>
 #include <cmath>
-#include <functional>
 
 // ggml_compute_forward_dup
 
@@ -7110,12 +7109,13 @@ void ggml_compute_forward_conv_2d_dw(
     }
 }
 
-// ggml_compute_forward_pool_1d_sk_p0
-
-static void ggml_compute_forward_pool_1d_sk_p0(
+// ggml_compute_forward_pool_1d_ksp
+static void ggml_compute_forward_pool_1d_ksp(
         const ggml_compute_params * params,
         const ggml_op_pool op,
         const int k,
+        const int s,
+        const int p,
         ggml_tensor * dst) {
 
     const ggml_tensor * src = dst->src[0];
@@ -7126,39 +7126,56 @@ static void ggml_compute_forward_pool_1d_sk_p0(
         return;
     }
 
-    const char * cdata = (const char *)src->data;
-    const char * const data_end = cdata + ggml_nbytes(src);
-    float * drow = (float *)dst->data;
+    const int64_t IW = src->ne[0];
+    const int64_t OW = dst->ne[0];
 
-    const int64_t rs = dst->ne[0];
+    const int64_t nr = ggml_nrows(src);
 
-    while (cdata < data_end) {
-        const void * srow = (const void *)cdata;
-        int j = 0;
-        for (int64_t i = 0; i < rs; ++i) {
+    for (int64_t ir = 0; ir < nr; ++ir) {
+        const char * srow_bytes =            (const char *) src->data + ir * src->nb[1];
+        float      * drow       = (float *) ((      char *) dst->data + ir * dst->nb[1]);
+
+        for (int64_t ow = 0; ow < OW; ++ow) {
+            float res = 0;
             switch (op) {
-                case GGML_OP_POOL_AVG:   drow[i] = 0;        break;
-                case GGML_OP_POOL_MAX:   drow[i] = -FLT_MAX; break;
+                case GGML_OP_POOL_AVG: res = 0.0f;     break;
+                case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
+
+            int count = 0;
+            const int base = (int) ow * s - p;
+
             for (int ki = 0; ki < k; ++ki) {
-                const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
+                const int j = base + ki;
+                if (j < 0 || j >= (int) IW) {
+                    continue;
+                }
+
+                float v;
+                if (src->type == GGML_TYPE_F32) {
+                    v = ((const float *) srow_bytes)[j];
+                } else {
+                    v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]);
+                }
+
                 switch (op) {
-                    case GGML_OP_POOL_AVG:                         drow[i] += srow_j; break;
-                    case GGML_OP_POOL_MAX:   if (srow_j > drow[i]) drow[i]  = srow_j; break;
-                    case GGML_OP_POOL_COUNT:                       GGML_ABORT("fatal error");
+                    case GGML_OP_POOL_AVG: res += v;                break;
+                    case GGML_OP_POOL_MAX: res =  std::max(v, res); break;
+                    case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
-                ++j;
+
+                ++count;
             }
+
             switch (op) {
-                case GGML_OP_POOL_AVG:         drow[i] /= k; break;
-                case GGML_OP_POOL_MAX:                       break;
+                case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break;
+                case GGML_OP_POOL_MAX:                                           break;
                 case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
             }
-        }
 
-        cdata += src->nb[1];
-        drow  += rs;
+            drow[ow] = res;
+        }
     }
 }
 
@@ -7173,10 +7190,8 @@ void ggml_compute_forward_pool_1d(
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
-    GGML_ASSERT(p0 == 0); // padding not supported
-    GGML_ASSERT(k0 == s0); // only s = k supported
 
-    ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst);
+    ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst);
 }
 
 // ggml_compute_forward_pool_2d
@@ -7194,6 +7209,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
+
     ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
@@ -7217,11 +7233,13 @@ void ggml_compute_forward_pool_2d(
     while (cdata < data_end) {
         for (int oy = 0; oy < py; ++oy) {
             float * const drow = dplane + oy * px;
+            float * const out  = drow;
+
             for (int ox = 0; ox < px; ++ox) {
-                float * const out =  drow + ox;
+                float res = 0;
                 switch (op) {
-                    case GGML_OP_POOL_AVG:     *out = 0;        break;
-                    case GGML_OP_POOL_MAX:     *out = -FLT_MAX; break;
+                    case GGML_OP_POOL_AVG: res = 0;        break;
+                    case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
 
@@ -7229,24 +7247,32 @@ void ggml_compute_forward_pool_2d(
                 const int iy = offset1 + oy * s1;
 
                 for (int ky = 0; ky < k1; ++ky) {
-                    if (iy + ky < 0 || iy + ky >= src->ne[1]) continue;
+                    if (iy + ky < 0 || iy + ky >= src->ne[1]) {
+                        continue;
+                    }
+
                     const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky));
                     for (int kx = 0; kx < k0; ++kx) {
                         int j = ix + kx;
-                        if (j < 0 || j >= src->ne[0]) continue;
+                        if (j < 0 || j >= src->ne[0]) {
+                            continue;
+                        }
+
                         const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]);
                         switch (op) {
-                            case GGML_OP_POOL_AVG:                     *out += srow_j; break;
-                            case GGML_OP_POOL_MAX: if (srow_j > *out)  *out  = srow_j; break;
+                            case GGML_OP_POOL_AVG: res += srow_j;                break;
+                            case GGML_OP_POOL_MAX: res =  std::max(srow_j, res); break;
                             case GGML_OP_POOL_COUNT:               GGML_ABORT("fatal error");
                         }
                     }
                 }
                 switch (op) {
-                    case GGML_OP_POOL_AVG:           *out /= ka; break;
-                    case GGML_OP_POOL_MAX:                       break;
+                    case GGML_OP_POOL_AVG:           res /= ka; break;
+                    case GGML_OP_POOL_MAX:                      break;
                     case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error");
                 }
+
+                out[ox] = res;
             }
         }
 

+ 25 - 0
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -94,6 +94,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l
     return res;
 }
 
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
+
+    const char * pool_str = "undefined";
+    switch (op_pool) {
+        case GGML_OP_POOL_AVG: pool_str = "avg"; break;
+        case GGML_OP_POOL_MAX: pool_str = "max"; break;
+        default: GGML_ASSERT(false && "not implemented");
+    };
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
+    snprintf(name, sizeof(name), "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-device.h

@@ -104,6 +104,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_
 
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);

+ 2 - 2
ggml/src/ggml-metal/ggml-metal-device.m

@@ -1044,10 +1044,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                    op->src[1]->type == GGML_TYPE_F32 &&
                    op->type == GGML_TYPE_F32 &&
                    (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
-        case GGML_OP_POOL_1D:
-            return false;
         case GGML_OP_UPSCALE:
             return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS);
+        case GGML_OP_POOL_1D:
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_POOL_2D:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_PAD:

+ 9 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -928,6 +928,15 @@ typedef struct {
     int64_t  np;
 } ggml_metal_kargs_pool_2d;
 
+typedef struct {
+    int32_t  k0;
+    int32_t  s0;
+    int32_t  p0;
+    int64_t  IW;
+    int64_t  OW;
+    int64_t  np;
+} ggml_metal_kargs_pool_1d;
+
 typedef struct {
      int64_t ne00;
     uint64_t nb01;

+ 52 - 0
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -432,6 +432,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_cpy(ctx, idx);
             } break;
+        case GGML_OP_POOL_1D:
+            {
+                n_fuse = ggml_metal_op_pool_1d(ctx, idx);
+            } break;
         case GGML_OP_POOL_2D:
             {
                 n_fuse = ggml_metal_op_pool_2d(ctx, idx);
@@ -1622,6 +1626,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    const int32_t * opts = op->op_params;
+    ggml_op_pool op_pool = (ggml_op_pool) opts[0];
+
+    const int32_t k0 = opts[1];
+    const int32_t s0 = opts[2];
+    const int32_t p0 = opts[3];
+
+    const int64_t IW = op->src[0]->ne[0];
+    const int64_t OW = op->ne[0];
+
+    const int64_t np = ggml_nelements(op);
+
+    ggml_metal_kargs_pool_1d args_pool_1d = {
+        /* .k0 = */  k0,
+        /* .s0 = */  s0,
+        /* .p0 = */  p0,
+        /* .IW = */  IW,
+        /* .OW = */  OW,
+        /* .np = */  np
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
+    const int ntg = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args_pool_1d, sizeof(args_pool_1d),  0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
+
 int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-ops.h

@@ -61,6 +61,7 @@ int ggml_metal_op_ssm_conv          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_ssm_scan          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rwkv              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_cpy               (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_pool_1d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_pool_2d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_mul_mat_id        (ggml_metal_op_t ctx, int idx);

+ 68 - 0
ggml/src/ggml-metal/ggml-metal.metal

@@ -9869,6 +9869,74 @@ kernel void kernel_pool_2d_avg_f32(
     o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
+
+kernel void kernel_pool_1d_max_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = -INFINITY;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        int j = base + ki;
+        if (j < 0 || j >= args.IW){
+            continue;
+        }
+        float v = src[src_off + j];
+        acc = max(acc, v);
+    }
+
+    dst[dst_off + ow] = acc;
+}
+
+kernel void kernel_pool_1d_avg_f32(
+        constant        ggml_metal_kargs_pool_1d & args,
+        device  const   float * src,
+        device          float * dst,
+        uint            gid [[thread_position_in_grid]]
+) {
+
+    if (gid >= args.np) {
+        return;
+    }
+
+    const int ow  = (int)gid % args.OW;
+    const int row = (int)gid / args.OW;
+
+    const int base = ow * args.s0 - args.p0;
+
+    float acc = 0.0f;
+    int   cnt = 0;
+
+    const int src_off = row * args.IW;
+    const int dst_off = row * args.OW;
+
+    for (int ki = 0; ki < args.k0; ++ki) {
+        const int j = base + ki;
+        if (j < 0 || j >= args.IW) {
+            continue;
+        }
+        acc += src[src_off + j];
+        cnt += 1;
+    }
+
+    dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f;
+}
+
 kernel void kernel_opt_step_adamw_f32(
         constant    ggml_metal_kargs_opt_step_adamw & args,
         device       float * x,

+ 5 - 0
ggml/src/ggml.c

@@ -4838,6 +4838,8 @@ struct ggml_tensor * ggml_pool_1d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, s0, p0 };
@@ -4868,6 +4870,9 @@ struct ggml_tensor * ggml_pool_2d(
         a->ne[2],
         a->ne[3],
     };
+    GGML_ASSERT(ne[0] > 0);
+    GGML_ASSERT(ne[1] > 0);
+
     result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
     int32_t params[] = { op, k0, k1, s0, s1, p0, p1 };

+ 45 - 0
tests/test-backend-ops.cpp

@@ -4679,6 +4679,37 @@ struct test_pool2d : public test_case {
     }
 };
 
+// GGML_OP_POOL1D
+struct test_pool1d : public test_case {
+    enum ggml_op_pool pool_type;
+    const ggml_type type_input;
+    const std::array<int64_t, 4> ne_input;
+    const int k0;
+    const int s0;
+    const int p0;
+
+    std::string vars() override {
+        return VARS_TO_STR6(pool_type, type_input, ne_input, k0, s0, p0);
+    }
+
+    test_pool1d(ggml_op_pool pool_type = GGML_OP_POOL_AVG,
+                ggml_type type_input = GGML_TYPE_F32,
+                std::array<int64_t,4> ne_input = {10, 1, 1, 1},
+                int k0 = 3, int s0 = 3, int p0 = 0)
+        : pool_type(pool_type), type_input(type_input), ne_input(ne_input), k0(k0), s0(s0), p0(p0) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * input = ggml_new_tensor(ctx, type_input, 4, ne_input.data());
+        ggml_set_param(input);
+        ggml_set_name(input, "input");
+
+        ggml_tensor * out = ggml_pool_1d(ctx, input, pool_type, k0, s0, p0);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_CONV_TRANSPOSE_1D
 struct test_conv_transpose_1d : public test_case {
     const std::array<int64_t, 4> ne_input;
@@ -7058,6 +7089,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (ggml_type type_input : {GGML_TYPE_F32}) {
+        for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
+            for (int k0 : {1, 3}) {
+                for (int s0 : {1, 2}) {
+                    for (int p0 : {0, 1}) {
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 10,  3, 2, 1 }, k0, s0, p0));
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 11,  1, 3, 2 }, k0, s0, p0));
+                        test_cases.emplace_back(new test_pool1d(pool_type, type_input, { 128, 2, 1, 3 }, k0, s0, p0));
+                    }
+                }
+            }
+        }
+    }
+
 #if 0
     // >4GB im2col destination. Too slow to run by default.
     // Test cases taken from Wan2.1 T2V 1.3B.