|
@@ -4541,69 +4541,179 @@ kernel void kernel_timestep_embedding_f32(
|
|
|
// bitonic sort implementation following the CUDA kernels as reference
|
|
// bitonic sort implementation following the CUDA kernels as reference
|
|
|
typedef void (argsort_t)(
|
|
typedef void (argsort_t)(
|
|
|
constant ggml_metal_kargs_argsort & args,
|
|
constant ggml_metal_kargs_argsort & args,
|
|
|
- device const float * x,
|
|
|
|
|
|
|
+ device const char * src0,
|
|
|
device int32_t * dst,
|
|
device int32_t * dst,
|
|
|
- threadgroup int32_t * shared_values [[threadgroup(0)]],
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]]);
|
|
|
|
|
|
|
+ threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]);
|
|
|
|
|
|
|
|
template<ggml_sort_order order>
|
|
template<ggml_sort_order order>
|
|
|
kernel void kernel_argsort_f32_i32(
|
|
kernel void kernel_argsort_f32_i32(
|
|
|
constant ggml_metal_kargs_argsort & args,
|
|
constant ggml_metal_kargs_argsort & args,
|
|
|
- device const float * x,
|
|
|
|
|
|
|
+ device const char * src0,
|
|
|
device int32_t * dst,
|
|
device int32_t * dst,
|
|
|
- threadgroup int32_t * shared_values [[threadgroup(0)]],
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- uint3 tpitg[[thread_position_in_threadgroup]]) {
|
|
|
|
|
|
|
+ threadgroup int32_t * smem_i32 [[threadgroup(0)]],
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
|
|
// bitonic sort
|
|
// bitonic sort
|
|
|
- int col = tpitg[0];
|
|
|
|
|
- int row = tgpig[1];
|
|
|
|
|
|
|
+ const int col = tpitg[0];
|
|
|
|
|
|
|
|
- if (col >= args.ncols_pad) return;
|
|
|
|
|
|
|
+ const int i00 = (tgpig[0]/args.ne01)*ntg.x;
|
|
|
|
|
+ const int i01 = tgpig[0]%args.ne01;
|
|
|
|
|
+ const int i02 = tgpig[1];
|
|
|
|
|
+ const int i03 = tgpig[2];
|
|
|
|
|
|
|
|
- device const float * x_row = x + row * args.ncols;
|
|
|
|
|
- threadgroup int32_t * dst_row = shared_values;
|
|
|
|
|
|
|
+ device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
|
|
|
|
|
|
|
|
// initialize indices
|
|
// initialize indices
|
|
|
- dst_row[col] = col;
|
|
|
|
|
|
|
+ smem_i32[col] = i00 + col;
|
|
|
|
|
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
|
|
|
|
- for (int k = 2; k <= args.ncols_pad; k *= 2) {
|
|
|
|
|
|
|
+ for (int k = 2; k <= ntg.x; k *= 2) {
|
|
|
for (int j = k / 2; j > 0; j /= 2) {
|
|
for (int j = k / 2; j > 0; j /= 2) {
|
|
|
int ixj = col ^ j;
|
|
int ixj = col ^ j;
|
|
|
if (ixj > col) {
|
|
if (ixj > col) {
|
|
|
if ((col & k) == 0) {
|
|
if ((col & k) == 0) {
|
|
|
- if (dst_row[col] >= args.ncols ||
|
|
|
|
|
- (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
|
|
|
|
- x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
|
|
|
|
- x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
|
|
|
|
|
|
+ if (smem_i32[col] >= args.ne00 ||
|
|
|
|
|
+ (smem_i32[ixj] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
|
|
|
|
+ x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
|
|
|
|
|
+ x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
|
|
|
) {
|
|
) {
|
|
|
- SWAP(dst_row[col], dst_row[ixj]);
|
|
|
|
|
|
|
+ SWAP(smem_i32[col], smem_i32[ixj]);
|
|
|
}
|
|
}
|
|
|
} else {
|
|
} else {
|
|
|
- if (dst_row[ixj] >= args.ncols ||
|
|
|
|
|
- (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
|
|
|
|
|
- x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
|
|
|
|
- x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
|
|
|
|
|
|
+ if (smem_i32[ixj] >= args.ne00 ||
|
|
|
|
|
+ (smem_i32[col] < args.ne00 && (order == GGML_SORT_ORDER_ASC ?
|
|
|
|
|
+ x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
|
|
|
|
|
+ x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
|
|
|
) {
|
|
) {
|
|
|
- SWAP(dst_row[col], dst_row[ixj]);
|
|
|
|
|
|
|
+ SWAP(smem_i32[col], smem_i32[ixj]);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// copy the result to dst without the padding
|
|
// copy the result to dst without the padding
|
|
|
- if (col < args.ncols) {
|
|
|
|
|
- dst[row * args.ncols + col] = dst_row[col];
|
|
|
|
|
|
|
+ if (i00 + col < args.ne00) {
|
|
|
|
|
+ dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
|
|
|
|
|
+
|
|
|
|
|
+ dst[col] = smem_i32[col];
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
|
|
|
|
|
|
|
|
+typedef void (argsort_merge_t)(
|
|
|
|
|
+ constant ggml_metal_kargs_argsort_merge & args,
|
|
|
|
|
+ device const char * src0,
|
|
|
|
|
+ device const int32_t * tmp,
|
|
|
|
|
+ device int32_t * dst,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]);
|
|
|
|
|
+
|
|
|
|
|
+template<ggml_sort_order order>
|
|
|
|
|
+kernel void kernel_argsort_merge_f32_i32(
|
|
|
|
|
+ constant ggml_metal_kargs_argsort_merge & args,
|
|
|
|
|
+ device const char * src0,
|
|
|
|
|
+ device const int32_t * tmp,
|
|
|
|
|
+ device int32_t * dst,
|
|
|
|
|
+ uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
+ ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
+ ushort3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
|
+ int im = tgpig[0] / args.ne01;
|
|
|
|
|
+ int i01 = tgpig[0] % args.ne01;
|
|
|
|
|
+ int i02 = tgpig[1];
|
|
|
|
|
+ int i03 = tgpig[2];
|
|
|
|
|
+
|
|
|
|
|
+ const int start = im * (2*args.len);
|
|
|
|
|
+
|
|
|
|
|
+ const int len0 = MIN(args.len, MAX(0, args.ne00 - (int)(start)));
|
|
|
|
|
+ const int len1 = MIN(args.len, MAX(0, args.ne00 - (int)(start + args.len)));
|
|
|
|
|
+
|
|
|
|
|
+ const int total = len0 + len1;
|
|
|
|
|
+
|
|
|
|
|
+ device const int32_t * tmp0 = tmp + start
|
|
|
|
|
+ + i01*args.ne00
|
|
|
|
|
+ + i02*args.ne00*args.ne01
|
|
|
|
|
+ + i03*args.ne00*args.ne01*args.ne02;
|
|
|
|
|
+
|
|
|
|
|
+ device const int32_t * tmp1 = tmp0 + args.len;
|
|
|
|
|
+
|
|
|
|
|
+ dst += start
|
|
|
|
|
+ + i01*args.ne00
|
|
|
|
|
+ + i02*args.ne00*args.ne01
|
|
|
|
|
+ + i03*args.ne00*args.ne01*args.ne02;
|
|
|
|
|
+
|
|
|
|
|
+ device const float * src0_row = (device const float *)(src0
|
|
|
|
|
+ + args.nb01*i01
|
|
|
|
|
+ + args.nb02*i02
|
|
|
|
|
+ + args.nb03*i03);
|
|
|
|
|
+
|
|
|
|
|
+ for (int k = tpitg.x; k < (int) total; k += ntg.x) {
|
|
|
|
|
+ // find partition (i,j) such that i+j = k
|
|
|
|
|
+ int low = k > len1 ? k - len1 : 0;
|
|
|
|
|
+ int high = MIN(k, len0);
|
|
|
|
|
+
|
|
|
|
|
+ while (low < high) {
|
|
|
|
|
+ const int mid = (low + high) >> 1;
|
|
|
|
|
+
|
|
|
|
|
+ const int32_t idx0 = tmp0[mid];
|
|
|
|
|
+ const int32_t idx1 = tmp1[k - mid - 1];
|
|
|
|
|
+
|
|
|
|
|
+ const float val0 = src0_row[idx0];
|
|
|
|
|
+ const float val1 = src0_row[idx1];
|
|
|
|
|
+
|
|
|
|
|
+ if (order == GGML_SORT_ORDER_ASC) {
|
|
|
|
|
+ if (val0 <= val1) {
|
|
|
|
|
+ low = mid + 1;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ high = mid;
|
|
|
|
|
+ }
|
|
|
|
|
+ } else {
|
|
|
|
|
+ if (val0 >= val1) {
|
|
|
|
|
+ low = mid + 1;
|
|
|
|
|
+ } else {
|
|
|
|
|
+ high = mid;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ const int i = low;
|
|
|
|
|
+ const int j = k - i;
|
|
|
|
|
+
|
|
|
|
|
+ int32_t out_idx;
|
|
|
|
|
+
|
|
|
|
|
+ if (i >= len0) {
|
|
|
|
|
+ out_idx = tmp1[j];
|
|
|
|
|
+ } else if (j >= len1) {
|
|
|
|
|
+ out_idx = tmp0[i];
|
|
|
|
|
+ } else {
|
|
|
|
|
+ const int32_t idx0 = tmp0[i];
|
|
|
|
|
+ const int32_t idx1 = tmp1[j];
|
|
|
|
|
+
|
|
|
|
|
+ const float val0 = src0_row[idx0];
|
|
|
|
|
+ const float val1 = src0_row[idx1];
|
|
|
|
|
+
|
|
|
|
|
+ out_idx = (order == GGML_SORT_ORDER_ASC)
|
|
|
|
|
+ ? (val0 <= val1 ? idx0 : idx1)
|
|
|
|
|
+ : (val0 >= val1 ? idx0 : idx1);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ dst[k] = out_idx;
|
|
|
|
|
+ }
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>;
|
|
|
|
|
+template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>;
|
|
|
|
|
+
|
|
|
kernel void kernel_leaky_relu_f32(
|
|
kernel void kernel_leaky_relu_f32(
|
|
|
constant ggml_metal_kargs_leaky_relu & args,
|
|
constant ggml_metal_kargs_leaky_relu & args,
|
|
|
device const float * src0,
|
|
device const float * src0,
|