|
@@ -5571,38 +5571,6 @@ kernel void kernel_flash_attn_ext_vec_reduce(
|
|
|
#undef DV
|
|
#undef DV
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-template<typename T>
|
|
|
|
|
-kernel void kernel_set(
|
|
|
|
|
- constant ggml_metal_kargs_set & args,
|
|
|
|
|
- device const char * src0,
|
|
|
|
|
- device const char * src1,
|
|
|
|
|
- device char * dst,
|
|
|
|
|
- uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
- ushort3 tpitg[[thread_position_in_threadgroup]],
|
|
|
|
|
- ushort3 ntg[[threads_per_threadgroup]]) {
|
|
|
|
|
- const int i13 = tgpig[2];
|
|
|
|
|
- const int i12 = tgpig[1];
|
|
|
|
|
- const int i11 = tgpig[0];
|
|
|
|
|
-
|
|
|
|
|
- const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
|
|
|
|
|
-
|
|
|
|
|
- const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
|
|
|
|
|
- const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
|
|
|
|
|
- const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
|
|
|
|
|
-
|
|
|
|
|
- device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
|
|
|
|
|
-
|
|
|
|
|
- for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
|
|
|
|
|
- device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
|
|
|
|
|
- dst_data[i10] = (T) src[0];
|
|
|
|
|
- }
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-typedef decltype(kernel_set<float>) kernel_set_t;
|
|
|
|
|
-
|
|
|
|
|
-template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
|
|
|
|
|
-template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
|
|
|
|
|
-
|
|
|
|
|
template<typename T0, typename T1>
|
|
template<typename T0, typename T1>
|
|
|
kernel void kernel_cpy(
|
|
kernel void kernel_cpy(
|
|
|
constant ggml_metal_kargs_cpy & args,
|
|
constant ggml_metal_kargs_cpy & args,
|