|
@@ -4146,6 +4146,120 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
|
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
|
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
|
|
|
|
|
|
|
|
// Direct 2D convolution.
//
// Each thread walks the flattened output index space (N * OC * OH * OW) with a
// stride equal to the total number of launched threads, so every output element
// is produced exactly once regardless of the launch configuration.
//
// For one output element (n, oc, oh, ow) the kernel accumulates, in float,
//     src[n, ic, iy, ix] * weights[oc, ic, ky, kx]
// over all ic in [0, IC) and over the kernel taps whose input coordinates
//     ix = ow*s0 - p0 + kx*d0,   iy = oh*s1 - p1 + ky*d1
// fall inside [0, IW) x [0, IH). Taps outside the input are skipped, i.e. they
// contribute nothing to the sum.
//
// TK is the element type the weights are read as; src is read as float and
// dst is written as float. All tensors are addressed through byte strides
// (nb0x for weights, nb1x for src, nbx for dst).
//
// NOTE(review): the tap-range clamping divides by d0/d1 and therefore assumes
// both are > 0 - confirm the host side guarantees this.
template <typename TK>
kernel void kernel_conv_2d(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const uint threads_per_tg = ntg.x * ntg.y * ntg.z;
    const uint local_thread   = tpitg.z * (ntg.x * ntg.y) + tpitg.y * ntg.x + tpitg.x;

    // Linearise in 64 bits: a 32-bit product wraps once the launch holds more
    // than 2^32 threads, which would make several threads start at the same
    // index and leave part of the output unwritten.
    const uint64_t tg_index      = ((uint64_t) tgpig.z * tgpg.y + tgpig.y) * tgpg.x + tgpig.x;
    const uint64_t thread_index  = tg_index * threads_per_tg + local_thread;
    const uint64_t total_threads = (uint64_t) threads_per_tg * tgpg.x * tgpg.y * tgpg.z;
    const uint64_t total_outputs = (uint64_t) args.N * args.OC * args.OH * args.OW;

    for (uint64_t index = thread_index; index < total_outputs; index += total_threads) {
        // Decompose the flat index into (n, oc, oh, ow), ow varying fastest.
        uint64_t tmp = index;

        const int32_t ow = tmp % args.OW; tmp /= args.OW;
        const int32_t oh = tmp % args.OH; tmp /= args.OH;
        const int32_t oc = tmp % args.OC; tmp /= args.OC;
        const int32_t n  = tmp;

        float acc = 0.0f;

        // Input coordinate touched by kernel tap (0, 0); may be negative.
        const int32_t base_x = ow*args.s0 - args.p0;
        const int32_t base_y = oh*args.s1 - args.p1;

        // Smallest ky with base_y + ky*d1 >= 0 (ceil division).
        int32_t ky_start = 0;
        if (base_y < 0) {
            ky_start = (-base_y + args.d1 - 1)/args.d1;
        }
        // One past the largest ky with base_y + ky*d1 <= IH - 1.
        int32_t ky_end = args.KH;
        const int32_t y_max = args.IH - 1 - base_y;
        if (y_max < 0) {
            // Even the first tap lies past the last input row: empty range.
            ky_end = ky_start;
        } else if (base_y + (args.KH - 1)*args.d1 >= args.IH) {
            ky_end = min(ky_end, y_max/args.d1 + 1);
        }

        // Same clamping along x.
        int32_t kx_start = 0;
        if (base_x < 0) {
            kx_start = (-base_x + args.d0 - 1)/args.d0;
        }
        int32_t kx_end = args.KW;
        const int32_t x_max = args.IW - 1 - base_x;
        if (x_max < 0) {
            kx_end = kx_start;
        } else if (base_x + (args.KW - 1)*args.d0 >= args.IW) {
            kx_end = min(kx_end, x_max/args.d0 + 1);
        }

        if (ky_start < ky_end && kx_start < kx_end) {
            const uint64_t src_base_n = (uint64_t) n  * args.nb13;
            const uint64_t w_base_oc  = (uint64_t) oc * args.nb03;

            for (int32_t ic = 0; ic < args.IC; ++ic) {
                const uint64_t src_base_nc = src_base_n + (uint64_t) ic * args.nb12;
                const uint64_t w_base_ocic = w_base_oc  + (uint64_t) ic * args.nb02;

                for (int32_t ky = ky_start; ky < ky_end; ++ky) {
                    // iy is non-negative here thanks to the ky_start clamp,
                    // so the cast to uint64_t below is safe.
                    const int32_t  iy           = base_y + ky*args.d1;
                    const uint64_t src_base_row = src_base_nc + (uint64_t) iy * args.nb11;
                    const uint64_t w_base_row   = w_base_ocic + (uint64_t) ky * args.nb01;

                    for (int32_t kx = kx_start; kx < kx_end; ++kx) {
                        const int32_t  ix       = base_x + kx*args.d0;
                        const uint64_t src_offs = src_base_row + (uint64_t) ix * args.nb10;
                        const uint64_t w_offs   = w_base_row   + (uint64_t) kx * args.nb00;

                        const float x = *(device const float *)(src + src_offs);
                        const float w = (float) (*(device const TK *)(weights + w_offs));

                        acc += x * w;
                    }
                }
            }
        }

        // The output is always written, even when no tap was in range (acc == 0).
        const uint64_t dst_offs =
            (uint64_t) n  * args.nb3 +
            (uint64_t) oc * args.nb2 +
            (uint64_t) oh * args.nb1 +
            (uint64_t) ow * args.nb0;

        *(device float *)(dst + dst_offs) = acc;
    }
}
|
|
|
|
|
+
|
|
|
|
|
// Explicit instantiation with TK = float (weights read as float),
// exposed to the host under the name "kernel_conv_2d_f32_f32".
template [[host_name("kernel_conv_2d_f32_f32")]]
kernel void kernel_conv_2d<float>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
|
|
// Explicit instantiation with TK = half (weights read as half and widened to
// float before accumulation), exposed to the host under the name
// "kernel_conv_2d_f16_f32".
template [[host_name("kernel_conv_2d_f16_f32")]]
kernel void kernel_conv_2d<half>(
        constant ggml_metal_kargs_conv_2d & args,
        device const char * weights,
        device const char * src,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
typedef void (conv_transpose_1d_t)(
|
|
typedef void (conv_transpose_1d_t)(
|
|
|
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
|
device const float * src0,
|
|
device const float * src0,
|