|
@@ -1070,20 +1070,20 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
|
|
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
|
|
|
static void rope_yarn(
|
|
static void rope_yarn(
|
|
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
|
float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
|
|
|
- float * cos_theta, float * sin_theta
|
|
|
|
|
|
|
+ thread float * cos_theta, thread float * sin_theta
|
|
|
) {
|
|
) {
|
|
|
// Get n-d rotational scaling corrected for extrapolation
|
|
// Get n-d rotational scaling corrected for extrapolation
|
|
|
float theta_interp = freq_scale * theta_extrap;
|
|
float theta_interp = freq_scale * theta_extrap;
|
|
|
float theta = theta_interp;
|
|
float theta = theta_interp;
|
|
|
if (ext_factor != 0.0f) {
|
|
if (ext_factor != 0.0f) {
|
|
|
- ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
|
|
|
|
|
|
+ float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
|
|
|
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
|
|
|
|
|
|
|
|
// Get n-d magnitude scaling corrected for interpolation
|
|
// Get n-d magnitude scaling corrected for interpolation
|
|
|
- mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
|
|
|
|
|
|
+ mscale *= 1.0f + 0.1f * log(1.0f / freq_scale);
|
|
|
}
|
|
}
|
|
|
- *cos_theta = cosf(theta) * mscale;
|
|
|
|
|
- *sin_theta = sinf(theta) * mscale;
|
|
|
|
|
|
|
+ *cos_theta = cos(theta) * mscale;
|
|
|
|
|
+ *sin_theta = sin(theta) * mscale;
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
|
// Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
|
|
@@ -1123,8 +1123,13 @@ typedef void (rope_t)(
|
|
|
constant int & n_past,
|
|
constant int & n_past,
|
|
|
constant int & n_dims,
|
|
constant int & n_dims,
|
|
|
constant int & mode,
|
|
constant int & mode,
|
|
|
|
|
+ constant int & n_orig_ctx,
|
|
|
constant float & freq_base,
|
|
constant float & freq_base,
|
|
|
constant float & freq_scale,
|
|
constant float & freq_scale,
|
|
|
|
|
+ constant float & ext_factor,
|
|
|
|
|
+ constant float & attn_factor,
|
|
|
|
|
+ constant float & beta_fast,
|
|
|
|
|
+ constant float & beta_slow,
|
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
uint tiitg[[thread_index_in_threadgroup]],
|
|
|
uint3 tptg[[threads_per_threadgroup]],
|
|
uint3 tptg[[threads_per_threadgroup]],
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
|
uint3 tgpig[[threadgroup_position_in_grid]]);
|
|
@@ -1153,6 +1158,7 @@ kernel void kernel_rope(
|
|
|
constant int & n_past,
|
|
constant int & n_past,
|
|
|
constant int & n_dims,
|
|
constant int & n_dims,
|
|
|
constant int & mode,
|
|
constant int & mode,
|
|
|
|
|
+ constant int & n_orig_ctx,
|
|
|
constant float & freq_base,
|
|
constant float & freq_base,
|
|
|
constant float & freq_scale,
|
|
constant float & freq_scale,
|
|
|
constant float & ext_factor,
|
|
constant float & ext_factor,
|