|
|
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
// Multi-section rotary embedding ("mrope"). Apart from how the base angle is
// chosen, this follows kernel_rope_neox.
//
// Each threadgroup works on the row addressed by
// (i1, i2, i3) = (tgpig.x, tgpig.y, tgpig.z); the threads of the group walk
// the row two elements at a time.
//
//   args - rope parameters, sizes/strides and the four section widths
//   src0 - input tensor, addressed in bytes through nb00..nb03
//   src1 - int32 positions; the stream for section k starts at element k*ne02
//   src2 - optional float frequency factors indexed by pair; treated as
//          absent when it is the same pointer as src0
//   dst  - output tensor, addressed in bytes through nb0..nb3
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i1 = tgpig.x;
    const int i2 = tgpig.y;
    const int i3 = tgpig.z;

    // byte addresses of the row; the per-element offset is added inside the loop
    device const char * src_row = src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01;
    device       char * dst_row = dst  + i3*args.nb3  + i2*args.nb2  + i1*args.nb1;

    device const int32_t * positions = (device const int32_t *) src1;

    const bool has_freq_factors = src2 != src0;

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    const float inv_ndims = -1.f/args.n_dims;

    // section layout: [0, sect_0) [sect_0, end_sect_1) [end_sect_1, end_sect_2) [end_sect_2, sect_dims)
    const int sect_dims  = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
    const int end_sect_1 = args.sect_0 + args.sect_1;
    const int end_sect_2 = args.sect_0 + args.sect_1 + args.sect_2;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 >= args.n_dims) {
            // outside the rotated range: copy the two elements through unchanged
            device const T * in  = (device const T *)(src_row + i0*args.nb00);
            device       T * out = (device       T *)(dst_row + i0*args.nb0);

            out[0] = in[0];
            out[1] = in[1];

            continue;
        }

        const int ic     = i0/2;          // index of the rotated pair
        const int sector = ic % sect_dims;

        // pick the position stream that belongs to this pair's section
        int stream = 3;
        if (sector < args.sect_0) {
            stream = 0;
        } else if (sector < end_sect_1) {
            stream = 1;
        } else if (sector < end_sect_2) {
            stream = 2;
        }

        const float theta_base  = (float) positions[i2 + args.ne02 * stream];
        const float theta       = theta_base * pow(args.freq_base, inv_ndims*i0);
        const float freq_factor = has_freq_factors ? ((device const float *) src2)[ic] : 1.0f;

        float cos_theta;
        float sin_theta;
        rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

        device const T * in  = (device const T *)(src_row + ic*args.nb00);
        device       T * out = (device       T *)(dst_row + ic*args.nb0);

        // the two halves of a pair sit n_dims/2 elements apart
        const float x0 = in[0];
        const float x1 = in[args.n_dims/2];

        out[0]             = x0*cos_theta - x1*sin_theta;
        out[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
    }
}
|
|
|
+
|
|
|
// Two-section rotary embedding used for vision inputs. It has the same
// overall shape as kernel_rope_multi, with these differences:
//   - only sections 0 and 1 are used, each with its own position stream
//   - the rotation exponent restarts at 0 at the beginning of every section
//   - the rotated range is 2*n_dims wide and the two halves of a pair sit
//     n_dims elements apart (instead of n_dims/2)
//
// Each threadgroup works on the row addressed by
// (i1, i2, i3) = (tgpig[0], tgpig[1], tgpig[2]); the threads of the group walk
// the row two elements at a time.
//
//   args - rope parameters, sizes/strides and the section widths
//   src0 - input tensor, addressed in bytes through nb00..nb03
//   src1 - int32 positions; section 0 reads pos[i2], section 1 reads pos[i2 + ne02]
//   src2 - optional float frequency factors indexed by pair; treated as
//          absent when it is the same pointer as src0
//   dst  - output tensor, addressed in bytes through nb0..nb3
template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector    = ic % sect_dims;

            float p;
            float theta_base;
            // Section 0 covers [0, sect_0). The boundary has to be sect_0 so
            // that it agrees with the `sector - sect_0` rebasing below (and
            // with kernel_rope_multi); comparing against sect_1 only gives the
            // same result when both sections have equal width.
            if (sector < args.sect_0) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device const T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device       T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            // outside the rotated range: copy the two elements through unchanged
            device const T * const src = (device const T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device       T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
|
|
|
+
|
|
|
// Function types of the templated rope kernels; the explicit instantiations
// use them to declare each host-visible variant.
typedef decltype(kernel_rope_norm<float>)   kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>)   kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>)  kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;

// kernel_rope_norm: float and half variants
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
|
|
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
|
|
// kernel_rope_neox: float and half variants
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;

// kernel_rope_multi: float and half variants
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;

// kernel_rope_vision: float and half variants
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
|
|
+
|
|
|
typedef void (im2col_t)(
|
|
|
device const float * x,
|
|
|
device char * dst,
|