|
@@ -4179,6 +4179,97 @@ kernel void kernel_conv_transpose_1d<half>(
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
uint3 tgpg[[threadgroups_per_grid]]);
|
|
uint3 tgpg[[threadgroups_per_grid]]);
|
|
|
|
|
|
|
|
|
|
+
|
|
|
|
|
// Function-type alias associated by name with the 2D transposed-convolution kernels.
// NOTE(review): this parameter list (no threadgroup scratch buffer, no per-thread position,
// `threadgroups_per_grid` instead) differs from the kernel_conv_transpose_2d<T> signature
// defined below - confirm whether the alias is still referenced anywhere before relying on it.
typedef void (conv_transpose_2d_t)(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3    tgpg[[threadgroups_per_grid]]);
|
|
|
|
|
+
|
|
|
|
|
// Transposed 2D convolution.
//
// Launch layout as consumed by this kernel:
//   - threadgroup position (x, y, z) selects the output column, row and channel;
//     each threadgroup produces exactly one output value
//   - thread position (x, y) inside the threadgroup selects the kernel tap (kw, kh)
//   - shared_sum (threadgroup buffer 0) must hold at least ntg.x * ntg.y floats
//
// src0 holds the kernel weights (element type T), src1 the float input; both are indexed as
// dense arrays using the extents in args. dst is addressed in bytes through args.nb0/nb1/nb2.
// The same stride args.s0 is applied along both spatial axes.
template <typename T>
kernel void kernel_conv_transpose_2d(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const T * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]) {

    const int64_t out_x = tgpig[0];
    const int64_t out_y = tgpig[1];
    const int64_t out_c = tgpig[2];

    const int64_t kw = tpitg[0];
    const int64_t kh = tpitg[1];

    // The input element touched by this tap depends only on (out_x, out_y, kw, kh), never on
    // the input channel, so resolve and validate it once instead of on every loop iteration.
    int64_t in_y = out_y - kh;
    int64_t in_x = out_x - kw;

    // Threads outside the KW x KH kernel window (oversized threadgroup) must not touch src0.
    bool contributes = kw < args.KW && kh < args.KH;

    // A tap contributes only if it lands exactly on a strided input sample ...
    contributes = contributes && in_y >= 0 && (in_y % args.s0) == 0;
    contributes = contributes && in_x >= 0 && (in_x % args.s0) == 0;

    if (contributes) {
        in_y /= args.s0;
        in_x /= args.s0;

        // ... that lies inside the input plane.
        contributes = in_y < args.IH && in_x < args.IW;
    }

    float v = 0.0f;

    if (contributes) {
        // Per-channel strides, widened before multiplying so the products are formed in 64 bits.
        const int64_t input_step  = (int64_t) args.IW * args.IH;
        const int64_t kernel_step = (int64_t) args.KH * args.KW * args.OC;

        int64_t input_idx  = (int64_t) args.IW * in_y + in_x;
        int64_t kernel_idx = (int64_t) args.KH * args.KW * out_c + (int64_t) args.KW * kh + kw;

        // Accumulate this tap's contribution over all input channels (same order as before).
        for (int64_t in_c = 0; in_c < args.IC; in_c++) {
            v += (float) src0[kernel_idx] * src1[input_idx];

            input_idx  += input_step;
            kernel_idx += kernel_step;
        }
    }

    // Publish the per-thread partial sum under a flattened 2D thread index.
    const uint tid = tpitg.y * ntg.x + tpitg.x;
    shared_sum[tid] = v;

    // Every thread reaches this barrier: there is no early exit above.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 adds the partial sums in index order and writes the single output value.
    if (tid == 0) {
        float total = 0.0f;
        const uint num_threads = ntg.x * ntg.y;
        for (uint i = 0; i < num_threads; i++) {
            total += shared_sum[i];
        }

        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y*args.nb1 + out_c*args.nb2);
        dst_ptr[0] = total;
    }
}
|
|
|
|
|
+
|
|
|
|
|
// Explicit instantiation for float kernel weights, exported to the host under the name
// "kernel_conv_transpose_2d_f32_f32".
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const float * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
|
|
// Explicit instantiation for half-precision kernel weights (input and output stay float),
// exported to the host under the name "kernel_conv_transpose_2d_f16_f32".
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
        constant ggml_metal_kargs_conv_transpose_2d & args,
        device const half  * src0,
        device const float * src1,
        device        char * dst,
        threadgroup  float * shared_sum [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        uint3   tpitg[[thread_position_in_threadgroup]],
        uint3     ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
kernel void kernel_upscale_f32(
|
|
kernel void kernel_upscale_f32(
|
|
|
constant ggml_metal_kargs_upscale & args,
|
|
constant ggml_metal_kargs_upscale & args,
|
|
|
device const char * src0,
|
|
device const char * src0,
|