|
@@ -3987,8 +3987,72 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kerne
|
|
|
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
|
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
|
|
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
|
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
|
|
|
|
|
|
|
|
|
+typedef void (im2col_t)(
|
|
|
|
|
+ constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
+ device const float * x,
|
|
|
|
|
+ device char * dst,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint3 tgpg[[threadgroups_per_grid]],
|
|
|
|
|
+ uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ uint3 ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
|
|
+template <typename T>
|
|
|
|
|
+kernel void kernel_im2col(
|
|
|
|
|
+ constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
+ device const float * x,
|
|
|
|
|
+ device char * dst,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ uint3 tgpg[[threadgroups_per_grid]],
|
|
|
|
|
+ uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
|
+// const int64_t IC = tgpg[0];
|
|
|
|
|
+ const int64_t OH = tgpg[1];
|
|
|
|
|
+ const int64_t OW = tgpg[2];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t KH = ntg[1];
|
|
|
|
|
+ const int64_t KW = ntg[2];
|
|
|
|
|
+
|
|
|
|
|
+ int64_t in = tpitg[0];
|
|
|
|
|
+ const int64_t ikh = tpitg[1];
|
|
|
|
|
+ const int64_t ikw = tpitg[2];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t iic = tgpig[0];
|
|
|
|
|
+ const int64_t ioh = tgpig[1];
|
|
|
|
|
+ const int64_t iow = tgpig[2];
|
|
|
|
|
+
|
|
|
|
|
+ const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
|
|
|
|
+ const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
|
|
|
|
+
|
|
|
|
|
+ int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
|
|
|
|
+
|
|
|
|
|
+ device T * pdst = (device T *) (dst);
|
|
|
|
|
+
|
|
|
|
|
+ if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
|
|
|
+ while (in < args.N) {
|
|
|
|
|
+ pdst[offset_dst] = 0.0f;
|
|
|
|
|
+ offset_dst += ntg[0]*args.CHW*OH*OW;
|
|
|
|
|
+
|
|
|
|
|
+ in += ntg[0];
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
|
|
|
|
+
|
|
|
|
|
+ while (in < args.N) {
|
|
|
|
|
+ pdst[offset_dst] = x[offset_src];
|
|
|
|
|
+
|
|
|
|
|
+ offset_dst += ntg[0]*args.CHW*OH*OW;
|
|
|
|
|
+ offset_src += ntg[0]*args.ofs0;
|
|
|
|
|
+
|
|
|
|
|
+ in += ntg[0];
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
|
|
|
|
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
|
|
|
+
|
|
|
// TODO: obolete -- remove
|
|
// TODO: obolete -- remove
|
|
|
-//typedef void (im2col_t)(
|
|
|
|
|
|
|
+//typedef void (im2col_ext_t)(
|
|
|
// constant ggml_metal_kargs_im2col & args,
|
|
// constant ggml_metal_kargs_im2col & args,
|
|
|
// device const float * x,
|
|
// device const float * x,
|
|
|
// device char * dst,
|
|
// device char * dst,
|
|
@@ -3998,100 +4062,48 @@ template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t ker
|
|
|
// uint3 ntg[[threads_per_threadgroup]]);
|
|
// uint3 ntg[[threads_per_threadgroup]]);
|
|
|
//
|
|
//
|
|
|
//template <typename T>
|
|
//template <typename T>
|
|
|
-//kernel void kernel_im2col(
|
|
|
|
|
|
|
+//kernel void kernel_im2col_ext(
|
|
|
// constant ggml_metal_kargs_im2col & args,
|
|
// constant ggml_metal_kargs_im2col & args,
|
|
|
// device const float * x,
|
|
// device const float * x,
|
|
|
// device char * dst,
|
|
// device char * dst,
|
|
|
// uint3 tgpig[[threadgroup_position_in_grid]],
|
|
// uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
-// uint3 tgpg[[threadgroups_per_grid]],
|
|
|
|
|
|
|
+// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
|
|
// uint3 tpitg[[thread_position_in_threadgroup]],
|
|
// uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
-// uint3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
|
-//// const int64_t IC = tgpg[0];
|
|
|
|
|
-// const int64_t OH = tgpg[1];
|
|
|
|
|
-// const int64_t OW = tgpg[2];
|
|
|
|
|
|
|
+// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
|
|
|
|
+// const int64_t KHW = (int64_t)args.KHW;
|
|
|
//
|
|
//
|
|
|
-//// const int64_t N = ntg[0];
|
|
|
|
|
-// const int64_t KH = ntg[1];
|
|
|
|
|
-// const int64_t KW = ntg[2];
|
|
|
|
|
|
|
+// const int64_t d = tgpig[0] / args.CHW;
|
|
|
|
|
+// const int64_t chw = tgpig[0] % args.CHW;
|
|
|
|
|
+// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
|
|
|
|
+// const int64_t HW = tgpig[0] % KHW;
|
|
|
//
|
|
//
|
|
|
-// const int64_t in = tpitg[0];
|
|
|
|
|
-// const int64_t ikh = tpitg[1];
|
|
|
|
|
-// const int64_t ikw = tpitg[2];
|
|
|
|
|
|
|
+// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
|
|
|
|
+// if (tpitg_0 >= args.N) {
|
|
|
|
|
+// return;
|
|
|
|
|
+// }
|
|
|
//
|
|
//
|
|
|
-// const int64_t iic = tgpig[0];
|
|
|
|
|
-// const int64_t ioh = tgpig[1];
|
|
|
|
|
-// const int64_t iow = tgpig[2];
|
|
|
|
|
|
|
+// const int64_t tpitg_1 = HW / args.KW;
|
|
|
|
|
+// const int64_t tpitg_2 = HW % args.KW;
|
|
|
//
|
|
//
|
|
|
-// const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
|
|
|
|
|
-// const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
|
|
|
|
|
|
|
+// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
|
|
|
|
+// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
|
|
//
|
|
//
|
|
|
-// const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
|
|
|
|
|
|
|
+// const int64_t offset_dst =
|
|
|
|
|
+// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
|
|
|
|
+// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
|
|
//
|
|
//
|
|
|
// device T * pdst = (device T *) (dst);
|
|
// device T * pdst = (device T *) (dst);
|
|
|
//
|
|
//
|
|
|
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
|
// pdst[offset_dst] = 0.0f;
|
|
// pdst[offset_dst] = 0.0f;
|
|
|
// } else {
|
|
// } else {
|
|
|
-// const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
|
|
|
|
|
-// pdst[offset_dst] = x[offset_src];
|
|
|
|
|
|
|
+// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
|
|
|
|
+// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
|
|
// }
|
|
// }
|
|
|
//}
|
|
//}
|
|
|
//
|
|
//
|
|
|
-//template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
|
|
|
|
|
-//template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
|
|
|
|
|
-
|
|
|
|
|
-typedef void (im2col_ext_t)(
|
|
|
|
|
- constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
- device const float * x,
|
|
|
|
|
- device char * dst,
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- uint3 tgpg[[threadgroups_per_grid]],
|
|
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
- uint3 ntg[[threads_per_threadgroup]]);
|
|
|
|
|
-
|
|
|
|
|
-template <typename T>
|
|
|
|
|
-kernel void kernel_im2col_ext(
|
|
|
|
|
- constant ggml_metal_kargs_im2col & args,
|
|
|
|
|
- device const float * x,
|
|
|
|
|
- device char * dst,
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
|
|
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
- uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1]
|
|
|
|
|
- const int64_t KHW = (int64_t)args.KHW;
|
|
|
|
|
-
|
|
|
|
|
- const int64_t d = tgpig[0] / args.CHW;
|
|
|
|
|
- const int64_t chw = tgpig[0] % args.CHW;
|
|
|
|
|
- const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1)
|
|
|
|
|
- const int64_t HW = tgpig[0] % KHW;
|
|
|
|
|
-
|
|
|
|
|
- const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
|
|
|
|
|
- if (tpitg_0 >= args.N) {
|
|
|
|
|
- return;
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- const int64_t tpitg_1 = HW / args.KW;
|
|
|
|
|
- const int64_t tpitg_2 = HW % args.KW;
|
|
|
|
|
-
|
|
|
|
|
- const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
|
|
|
|
|
- const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
|
|
|
|
|
-
|
|
|
|
|
- const int64_t offset_dst =
|
|
|
|
|
- (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
|
|
|
|
|
- (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
|
|
|
|
|
-
|
|
|
|
|
- device T * pdst = (device T *) (dst);
|
|
|
|
|
-
|
|
|
|
|
- if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
|
|
|
|
|
- pdst[offset_dst] = 0.0f;
|
|
|
|
|
- } else {
|
|
|
|
|
- const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
|
|
|
|
|
- pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
|
|
|
-template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
|
|
|
|
|
+//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>;
|
|
|
|
|
+//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>;
|
|
|
|
|
|
|
|
typedef void (conv_transpose_1d_t)(
|
|
typedef void (conv_transpose_1d_t)(
|
|
|
constant ggml_metal_kargs_conv_transpose_1d & args,
|
|
constant ggml_metal_kargs_conv_transpose_1d & args,
|