Просмотр исходного кода

metal: accelerated conv2d (#17175)

* metal: accelerated conv2d

* cont : cleanup

---------

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
bagheera 2 месяцев назад
Родитель
Сommit
0cfb19166b

+ 24 - 0
ggml/src/ggml-metal/ggml-metal-device.cpp

@@ -1438,6 +1438,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_CONV_2D);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type         == GGML_TYPE_F32);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
 

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-device.h

@@ -133,6 +133,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope              (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);

+ 5 - 0
ggml/src/ggml-metal/ggml-metal-device.m

@@ -885,6 +885,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return true;
         case GGML_OP_IM2COL:
             return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
+        case GGML_OP_CONV_2D:
+            return ggml_is_contiguous(op->src[0]) &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   op->type == GGML_TYPE_F32 &&
+                   (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
         case GGML_OP_POOL_1D:
             return false;
         case GGML_OP_UPSCALE:

+ 30 - 0
ggml/src/ggml-metal/ggml-metal-impl.h

@@ -528,6 +528,36 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_conv_transpose_2d;
 
+typedef struct {
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  KW;
+    int32_t  KH;
+    int32_t  IC;
+    int32_t  OC;
+    int32_t  OW;
+    int32_t  OH;
+    int32_t  N;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+} ggml_metal_kargs_conv_2d;
+
 typedef struct {
     uint64_t  ofs0;
     uint64_t  ofs1;

+ 83 - 5
ggml/src/ggml-metal/ggml-metal-ops.cpp

@@ -10,6 +10,7 @@
 
 #include <cassert>
 #include <algorithm>
+#include <limits>
 
 static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
     if (!t) {
@@ -364,6 +365,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_im2col(ctx, idx);
             } break;
+        case GGML_OP_CONV_2D:
+            {
+                n_fuse = ggml_metal_op_conv_2d(ctx, idx);
+            } break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
                 n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
@@ -1036,11 +1041,6 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
 
     nth = std::min(nth, nk0);
 
-    if (nth*nrptg > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
-        nrptg = 1;
-    }
-
     ggml_metal_kargs_set_rows args = {
         /*.nk0  =*/ nk0,
         /*.ne01 =*/ ne01,
@@ -3082,6 +3082,84 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->type == GGML_TYPE_F32);
+    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
+
+    const int32_t s0 = ((const int32_t *) op->op_params)[0];
+    const int32_t s1 = ((const int32_t *) op->op_params)[1];
+    const int32_t p0 = ((const int32_t *) op->op_params)[2];
+    const int32_t p1 = ((const int32_t *) op->op_params)[3];
+    const int32_t d0 = ((const int32_t *) op->op_params)[4];
+    const int32_t d1 = ((const int32_t *) op->op_params)[5];
+
+    ggml_metal_kargs_conv_2d args = {
+        /*.nb00 =*/ nb00,
+        /*.nb01 =*/ nb01,
+        /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
+        /*.nb10 =*/ nb10,
+        /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
+        /*.nb13 =*/ nb13,
+        /*.nb0  =*/ nb0,
+        /*.nb1  =*/ nb1,
+        /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
+        /*.IW   =*/ ne10,
+        /*.IH   =*/ ne11,
+        /*.KW   =*/ ne00,
+        /*.KH   =*/ ne01,
+        /*.IC   =*/ ne02,
+        /*.OC   =*/ ne03,
+        /*.OW   =*/ ne0,
+        /*.OH   =*/ ne1,
+        /*.N    =*/ ne3,
+        /*.s0   =*/ s0,
+        /*.s1   =*/ s1,
+        /*.p0   =*/ p0,
+        /*.p1   =*/ p1,
+        /*.d0   =*/ d0,
+        /*.d1   =*/ d1,
+    };
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
+
+    int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
+    nth = std::min(nth, 256);
+    nth = std::max(nth, 1);
+
+    const uint64_t n_out = ggml_nelements(op);
+
+    uint64_t tg = (n_out + nth - 1)/nth;
+    tg = std::max<uint64_t>(tg, 1);
+    tg = std::min<uint64_t>(tg, (uint64_t) std::numeric_limits<int>::max());
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, tg, 1, 1, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 

+ 1 - 0
ggml/src/ggml-metal/ggml-metal-ops.h

@@ -70,6 +70,7 @@ int ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_conv_2d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_upscale           (ggml_metal_op_t ctx, int idx);

+ 114 - 0
ggml/src/ggml-metal/ggml-metal.metal

@@ -4146,6 +4146,120 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
 //template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
 //template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
 
+template <typename TK>
+kernel void kernel_conv_2d(
+        constant ggml_metal_kargs_conv_2d & args,
+        device const char * weights,
+        device const char * src,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]],
+        uint3   tpitg[[thread_position_in_threadgroup]],
+        uint3     ntg[[threads_per_threadgroup]]) {
+
+    const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
+    const uint tg_index = (tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
+    const uint local_thread = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;
+    const uint thread_index = tg_index * threads_per_tg + local_thread;
+    const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
+    const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;
+
+    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
+        uint64_t tmp = index;
+
+        const int32_t ow = tmp % args.OW; tmp /= args.OW;
+        const int32_t oh = tmp % args.OH; tmp /= args.OH;
+        const int32_t oc = tmp % args.OC; tmp /= args.OC;
+        const int32_t  n = tmp;
+
+        float acc = 0.0f;
+
+        const int32_t base_x = ow*args.s0 - args.p0;
+        const int32_t base_y = oh*args.s1 - args.p1;
+
+        int32_t ky_start = 0;
+        if (base_y < 0) {
+            ky_start = (-base_y + args.d1 - 1)/args.d1;
+        }
+        int32_t ky_end = args.KH;
+        const int32_t y_max = args.IH - 1 - base_y;
+        if (y_max < 0) {
+            ky_end = ky_start;
+        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
+            ky_end = min(ky_end, y_max/args.d1 + 1);
+        }
+
+        int32_t kx_start = 0;
+        if (base_x < 0) {
+            kx_start = (-base_x + args.d0 - 1)/args.d0;
+        }
+        int32_t kx_end = args.KW;
+        const int32_t x_max = args.IW - 1 - base_x;
+        if (x_max < 0) {
+            kx_end = kx_start;
+        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
+            kx_end = min(kx_end, x_max/args.d0 + 1);
+        }
+
+        if (ky_start < ky_end && kx_start < kx_end) {
+            const uint64_t src_base_n = (uint64_t) n  * args.nb13;
+            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;
+
+            for (int32_t ic = 0; ic < args.IC; ++ic) {
+                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
+                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;
+
+                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
+                    const int32_t iy = base_y + ky*args.d1;
+                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
+                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;
+
+                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
+                        const int32_t ix = base_x + kx*args.d0;
+                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
+                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;
+
+                        const float x = *(device const float *)(src + src_offs);
+                        const float w = (float) (*(device const TK *)(weights + w_offs));
+
+                        acc += x * w;
+                    }
+                }
+            }
+        }
+
+        const uint64_t dst_offs =
+            (uint64_t) n  * args.nb3 +
+            (uint64_t) oc * args.nb2 +
+            (uint64_t) oh * args.nb1 +
+            (uint64_t) ow * args.nb0;
+
+        *(device float *)(dst + dst_offs) = acc;
+    }
+}
+
+template [[host_name("kernel_conv_2d_f32_f32")]]
+kernel void kernel_conv_2d<float>(
+        constant ggml_metal_kargs_conv_2d & args,
+        device const char * weights,
+        device const char * src,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]],
+        uint3   tpitg[[thread_position_in_threadgroup]],
+        uint3     ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_2d_f16_f32")]]
+kernel void kernel_conv_2d<half>(
+        constant ggml_metal_kargs_conv_2d & args,
+        device const char * weights,
+        device const char * src,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        uint3    tgpg[[threadgroups_per_grid]],
+        uint3   tpitg[[thread_position_in_threadgroup]],
+        uint3     ntg[[threads_per_threadgroup]]);
+
 typedef void (conv_transpose_1d_t)(
         constant ggml_metal_kargs_conv_transpose_1d & args,
         device const float * src0,