|
|
@@ -4970,8 +4970,13 @@ struct ggml_tensor * ggml_rope_back(
|
|
|
int n_dims,
|
|
|
int mode,
|
|
|
int n_ctx,
|
|
|
+ int n_orig_ctx,
|
|
|
float freq_base,
|
|
|
float freq_scale,
|
|
|
+ float ext_factor,
|
|
|
+ float attn_factor,
|
|
|
+ float beta_fast,
|
|
|
+ float beta_slow,
|
|
|
float xpos_base,
|
|
|
bool xpos_down) {
|
|
|
GGML_ASSERT(ggml_is_vector(b));
|
|
|
@@ -4988,11 +4993,15 @@ struct ggml_tensor * ggml_rope_back(
|
|
|
|
|
|
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
|
|
|
|
|
|
- int32_t params[8] = { /*n_past*/ 0, n_dims, mode, n_ctx };
|
|
|
- memcpy(params + 4, &freq_base, sizeof(float));
|
|
|
- memcpy(params + 5, &freq_scale, sizeof(float));
|
|
|
- memcpy(params + 6, &xpos_base, sizeof(float));
|
|
|
- memcpy(params + 7, &xpos_down, sizeof(bool));
|
|
|
+ int32_t params[13] = { /*n_past*/ 0, n_dims, mode, n_ctx, n_orig_ctx };
|
|
|
+ memcpy(params + 5, &freq_base, sizeof(float));
|
|
|
+ memcpy(params + 6, &freq_scale, sizeof(float));
|
|
|
+ memcpy(params + 7, &ext_factor, sizeof(float));
|
|
|
+ memcpy(params + 8, &attn_factor, sizeof(float));
|
|
|
+ memcpy(params + 9, &beta_fast, sizeof(float));
|
|
|
+ memcpy(params + 10, &beta_slow, sizeof(float));
|
|
|
+ memcpy(params + 11, &xpos_base, sizeof(float));
|
|
|
+ memcpy(params + 12, &xpos_down, sizeof(bool));
|
|
|
ggml_set_op_params(result, params, sizeof(params));
|
|
|
|
|
|
result->op = GGML_OP_ROPE_BACK;
|
|
|
@@ -10974,7 +10983,8 @@ static void ggml_compute_forward_rope_f32(
|
|
|
const struct ggml_compute_params * params,
|
|
|
const struct ggml_tensor * src0,
|
|
|
const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
+ struct ggml_tensor * dst,
|
|
|
+ const bool forward) {
|
|
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
|
|
return;
|
|
|
}
|
|
|
@@ -11033,6 +11043,11 @@ static void ggml_compute_forward_rope_f32(
|
|
|
const bool is_neox = mode & 2;
|
|
|
const bool is_glm = mode & 4;
|
|
|
|
|
|
+ // backward process uses inverse rotation by cos and sin.
|
|
|
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
|
+ // this essentially just switches the sign of sin.
|
|
|
+ const float sin_sign = forward ? 1.0f : -1.0f;
|
|
|
+
|
|
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
|
|
|
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
|
@@ -11049,9 +11064,9 @@ static void ggml_compute_forward_rope_f32(
|
|
|
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
|
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
|
const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
+ const float sin_theta = sinf(theta_base) * sin_sign;
|
|
|
const float cos_block_theta = cosf(block_theta);
|
|
|
- const float sin_block_theta = sinf(block_theta);
|
|
|
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
|
|
|
|
theta_base *= theta_scale;
|
|
|
block_theta *= theta_scale;
|
|
|
@@ -11075,6 +11090,7 @@ static void ggml_compute_forward_rope_f32(
|
|
|
rope_yarn(
|
|
|
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
|
|
);
|
|
|
+ sin_theta *= sin_sign;
|
|
|
|
|
|
// zeta scaling for xPos only:
|
|
|
float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
|
|
@@ -11105,6 +11121,7 @@ static void ggml_compute_forward_rope_f32(
|
|
|
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
|
|
&cos_theta, &sin_theta
|
|
|
);
|
|
|
+ sin_theta *= sin_sign;
|
|
|
|
|
|
theta_base *= theta_scale;
|
|
|
|
|
|
@@ -11130,7 +11147,8 @@ static void ggml_compute_forward_rope_f16(
|
|
|
const struct ggml_compute_params * params,
|
|
|
const struct ggml_tensor * src0,
|
|
|
const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
+ struct ggml_tensor * dst,
|
|
|
+ const bool forward) {
|
|
|
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
|
|
return;
|
|
|
}
|
|
|
@@ -11182,6 +11200,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
const bool is_neox = mode & 2;
|
|
|
const bool is_glm = mode & 4;
|
|
|
|
|
|
+ // backward process uses inverse rotation by cos and sin.
|
|
|
+ // cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
|
+ // this essentially just switches the sign of sin.
|
|
|
+ const float sin_sign = forward ? 1.0f : -1.0f;
|
|
|
+
|
|
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
|
|
|
|
for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
|
@@ -11198,9 +11221,9 @@ static void ggml_compute_forward_rope_f16(
|
|
|
float block_theta = MAX(p - (n_ctx - 2), 0);
|
|
|
for (int64_t i0 = 0; i0 < ne0 / 4; i0++) {
|
|
|
const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
+ const float sin_theta = sinf(theta_base) * sin_sign;
|
|
|
const float cos_block_theta = cosf(block_theta);
|
|
|
- const float sin_block_theta = sinf(block_theta);
|
|
|
+ const float sin_block_theta = sinf(block_theta) * sin_sign;
|
|
|
|
|
|
theta_base *= theta_scale;
|
|
|
block_theta *= theta_scale;
|
|
|
@@ -11224,6 +11247,7 @@ static void ggml_compute_forward_rope_f16(
|
|
|
rope_yarn(
|
|
|
theta_base, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta
|
|
|
);
|
|
|
+ sin_theta *= sin_sign;
|
|
|
|
|
|
theta_base *= theta_scale;
|
|
|
|
|
|
@@ -11250,6 +11274,7 @@ static void ggml_compute_forward_rope_f16(
|
|
|
theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
|
|
|
&cos_theta, &sin_theta
|
|
|
);
|
|
|
+ sin_theta *= sin_sign;
|
|
|
|
|
|
theta_base *= theta_scale;
|
|
|
|
|
|
@@ -11279,11 +11304,11 @@ static void ggml_compute_forward_rope(
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
|
- ggml_compute_forward_rope_f16(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_rope_f16(params, src0, src1, dst, true);
|
|
|
} break;
|
|
|
case GGML_TYPE_F32:
|
|
|
{
|
|
|
- ggml_compute_forward_rope_f32(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_rope_f32(params, src0, src1, dst, true);
|
|
|
} break;
|
|
|
default:
|
|
|
{
|
|
|
@@ -11294,216 +11319,6 @@ static void ggml_compute_forward_rope(
|
|
|
|
|
|
// ggml_compute_forward_rope_back
|
|
|
|
|
|
-static void ggml_compute_forward_rope_back_f32(
|
|
|
- const struct ggml_compute_params * params,
|
|
|
- const struct ggml_tensor * src0,
|
|
|
- const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- // y = rope(x, src1)
|
|
|
- // dx = rope_back(dy, src1)
|
|
|
- // src0 is dy, src1 contains options
|
|
|
-
|
|
|
- float freq_base;
|
|
|
- float freq_scale;
|
|
|
-
|
|
|
- // these two only relevant for xPos RoPE:
|
|
|
- float xpos_base;
|
|
|
- bool xpos_down;
|
|
|
-
|
|
|
- //const int n_past = ((int32_t *) dst->op_params)[0];
|
|
|
- const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
|
- const int mode = ((int32_t *) dst->op_params)[2];
|
|
|
- const int n_ctx = ((int32_t *) dst->op_params)[3]; UNUSED(n_ctx);
|
|
|
- memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
|
|
|
- memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
|
- memcpy(&xpos_base, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
|
- memcpy(&xpos_down, (int32_t *) dst->op_params + 7, sizeof(bool));
|
|
|
-
|
|
|
- GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
-
|
|
|
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
|
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
|
-
|
|
|
- assert(nb0 == sizeof(float));
|
|
|
-
|
|
|
- const int ith = params->ith;
|
|
|
- const int nth = params->nth;
|
|
|
-
|
|
|
- const int nr = ggml_nrows(dst);
|
|
|
-
|
|
|
- // rows per thread
|
|
|
- const int dr = (nr + nth - 1)/nth;
|
|
|
-
|
|
|
- // row range for this thread
|
|
|
- const int ir0 = dr*ith;
|
|
|
- const int ir1 = MIN(ir0 + dr, nr);
|
|
|
-
|
|
|
- // row index used to determine which thread to use
|
|
|
- int ir = 0;
|
|
|
-
|
|
|
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
-
|
|
|
- const bool is_neox = mode & 2;
|
|
|
-
|
|
|
- const int32_t * pos = (const int32_t *) src1->data;
|
|
|
-
|
|
|
- for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
|
- for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
|
- const int64_t p = pos[i2];
|
|
|
- for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
|
- if (ir++ < ir0) continue;
|
|
|
- if (ir > ir1) break;
|
|
|
-
|
|
|
- float theta_base = freq_scale * (float)p;
|
|
|
-
|
|
|
- if (!is_neox) {
|
|
|
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
|
- const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
-
|
|
|
- // zeta scaling for xPos only:
|
|
|
- float zeta = xpos_base != 0.0f ? powf((i0 + 0.4f * ne0) / (1.4f * ne0), p / xpos_base) : 1.0f;
|
|
|
- if (xpos_down) zeta = 1.0f / zeta;
|
|
|
-
|
|
|
- theta_base *= theta_scale;
|
|
|
-
|
|
|
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- const float dy0 = dy[0];
|
|
|
- const float dy1 = dy[1];
|
|
|
-
|
|
|
- dx[0] = dy0*cos_theta*zeta + dy1*sin_theta*zeta;
|
|
|
- dx[1] = - dy0*sin_theta*zeta + dy1*cos_theta*zeta;
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
|
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
|
|
- const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
-
|
|
|
- theta_base *= theta_scale;
|
|
|
-
|
|
|
- const int64_t i0 = ib*n_dims + ic/2;
|
|
|
-
|
|
|
- const float * const dy = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
- float * dx = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- const float dy0 = dy[0];
|
|
|
- const float dy1 = dy[n_dims/2];
|
|
|
-
|
|
|
- dx[0] = dy0*cos_theta + dy1*sin_theta;
|
|
|
- dx[n_dims/2] = - dy0*sin_theta + dy1*cos_theta;
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
-static void ggml_compute_forward_rope_back_f16(
|
|
|
- const struct ggml_compute_params * params,
|
|
|
- const struct ggml_tensor * src0,
|
|
|
- const struct ggml_tensor * src1,
|
|
|
- struct ggml_tensor * dst) {
|
|
|
-
|
|
|
- if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
|
|
- return;
|
|
|
- }
|
|
|
-
|
|
|
- // y = rope(x, src1)
|
|
|
- // dx = rope_back(dy, src1)
|
|
|
- // src0 is dy, src1 contains options
|
|
|
-
|
|
|
- //const int n_past = ((int32_t *) dst->op_params)[0];
|
|
|
- const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
|
- const int mode = ((int32_t *) dst->op_params)[2];
|
|
|
-
|
|
|
- GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
-
|
|
|
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
|
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
|
-
|
|
|
- assert(nb0 == sizeof(ggml_fp16_t));
|
|
|
-
|
|
|
- const int ith = params->ith;
|
|
|
- const int nth = params->nth;
|
|
|
-
|
|
|
- const int nr = ggml_nrows(dst);
|
|
|
-
|
|
|
- // rows per thread
|
|
|
- const int dr = (nr + nth - 1)/nth;
|
|
|
-
|
|
|
- // row range for this thread
|
|
|
- const int ir0 = dr*ith;
|
|
|
- const int ir1 = MIN(ir0 + dr, nr);
|
|
|
-
|
|
|
- // row index used to determine which thread to use
|
|
|
- int ir = 0;
|
|
|
-
|
|
|
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
|
|
|
-
|
|
|
- const bool is_neox = mode & 2;
|
|
|
-
|
|
|
- const int32_t * pos = (const int32_t *) src1->data;
|
|
|
-
|
|
|
- for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
|
- for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
|
- const int64_t p = pos[i2];
|
|
|
- for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
|
- if (ir++ < ir0) continue;
|
|
|
- if (ir > ir1) break;
|
|
|
-
|
|
|
- float theta_base = (float)p;
|
|
|
-
|
|
|
- if (!is_neox) {
|
|
|
- for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
|
|
|
- const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
-
|
|
|
- theta_base *= theta_scale;
|
|
|
-
|
|
|
- const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
- ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- const float dy0 = GGML_FP16_TO_FP32(dy[0]);
|
|
|
- const float dy1 = GGML_FP16_TO_FP32(dy[1]);
|
|
|
-
|
|
|
- dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
|
|
- dx[1] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
|
|
- }
|
|
|
- } else {
|
|
|
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
|
|
|
- for (int64_t ic = 0; ic < n_dims; ic += 2) {
|
|
|
- const float cos_theta = cosf(theta_base);
|
|
|
- const float sin_theta = sinf(theta_base);
|
|
|
-
|
|
|
- theta_base *= theta_scale;
|
|
|
-
|
|
|
- const int64_t i0 = ib*n_dims + ic/2;
|
|
|
-
|
|
|
- const ggml_fp16_t * const dy = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
- ggml_fp16_t * dx = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
-
|
|
|
- const float dy0 = GGML_FP16_TO_FP32(dy[0]);
|
|
|
- const float dy1 = GGML_FP16_TO_FP32(dy[n_dims/2]);
|
|
|
-
|
|
|
- dx[0] = GGML_FP32_TO_FP16( dy0*cos_theta + dy1*sin_theta);
|
|
|
- dx[n_dims/2] = GGML_FP32_TO_FP16(-dy0*sin_theta + dy1*cos_theta);
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
-}
|
|
|
-
|
|
|
static void ggml_compute_forward_rope_back(
|
|
|
const struct ggml_compute_params * params,
|
|
|
const struct ggml_tensor * src0,
|
|
|
@@ -11512,11 +11327,11 @@ static void ggml_compute_forward_rope_back(
|
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
|
- ggml_compute_forward_rope_back_f16(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_rope_f16(params, src0, src1, dst, false);
|
|
|
} break;
|
|
|
case GGML_TYPE_F32:
|
|
|
{
|
|
|
- ggml_compute_forward_rope_back_f32(params, src0, src1, dst);
|
|
|
+ ggml_compute_forward_rope_f32(params, src0, src1, dst, false);
|
|
|
} break;
|
|
|
default:
|
|
|
{
|
|
|
@@ -15559,17 +15374,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
// necessary for llama
|
|
|
if (src0->grad) {
|
|
|
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
|
|
- const int n_dims = ((int32_t *) tensor->op_params)[1];
|
|
|
- const int mode = ((int32_t *) tensor->op_params)[2];
|
|
|
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
|
|
- float freq_base;
|
|
|
- float freq_scale;
|
|
|
- float xpos_base;
|
|
|
- bool xpos_down;
|
|
|
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
|
|
|
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
|
|
|
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
|
|
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
|
|
|
+ const int mode = ((int32_t *) tensor->op_params)[2];
|
|
|
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
|
|
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
|
|
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base; bool xpos_down; // xpos_down is stored and read back as a bool (sizeof(bool) memcpy below); declaring it float left part of the object indeterminate
|
|
|
+
|
|
|
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
|
|
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
|
|
|
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
|
|
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
|
|
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
|
|
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
|
|
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
|
|
|
|
|
src0->grad = ggml_add_or_set(ctx,
|
|
|
src0->grad,
|
|
|
@@ -15579,8 +15397,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
n_dims,
|
|
|
mode,
|
|
|
n_ctx,
|
|
|
+ n_orig_ctx,
|
|
|
freq_base,
|
|
|
freq_scale,
|
|
|
+ ext_factor,
|
|
|
+ attn_factor,
|
|
|
+ beta_fast,
|
|
|
+ beta_slow,
|
|
|
xpos_base,
|
|
|
xpos_down),
|
|
|
zero_table);
|
|
|
@@ -15590,17 +15413,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
{
|
|
|
if (src0->grad) {
|
|
|
//const int n_past = ((int32_t *) tensor->op_params)[0];
|
|
|
- const int n_dims = ((int32_t *) tensor->op_params)[1];
|
|
|
- const int mode = ((int32_t *) tensor->op_params)[2];
|
|
|
- const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
|
|
- float freq_base;
|
|
|
- float freq_scale;
|
|
|
- float xpos_base;
|
|
|
- bool xpos_down;
|
|
|
- memcpy(&freq_base, (int32_t *) tensor->op_params + 4, sizeof(float));
|
|
|
- memcpy(&freq_scale, (int32_t *) tensor->op_params + 5, sizeof(float));
|
|
|
- memcpy(&xpos_base, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
- memcpy(&xpos_down, (int32_t *) tensor->op_params + 7, sizeof(bool));
|
|
|
+ const int n_dims = ((int32_t *) tensor->op_params)[1];
|
|
|
+ const int mode = ((int32_t *) tensor->op_params)[2];
|
|
|
+ const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
|
|
+ const int n_orig_ctx = ((int32_t *) tensor->op_params)[4];
|
|
|
+ float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow, xpos_base; bool xpos_down; // xpos_down is stored and read back as a bool (sizeof(bool) memcpy below); declaring it float left part of the object indeterminate
|
|
|
+
|
|
|
+ memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float));
|
|
|
+ memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float));
|
|
|
+ memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float));
|
|
|
+ memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float));
|
|
|
+ memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float));
|
|
|
+ memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float));
|
|
|
+ memcpy(&xpos_base, (int32_t *) tensor->op_params + 11, sizeof(float));
|
|
|
+ memcpy(&xpos_down, (int32_t *) tensor->op_params + 12, sizeof(bool));
|
|
|
|
|
|
src0->grad = ggml_add_or_set(ctx,
|
|
|
src0->grad,
|
|
|
@@ -15609,14 +15435,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|
|
src1,
|
|
|
n_dims,
|
|
|
mode,
|
|
|
- 0,
|
|
|
n_ctx,
|
|
|
+ n_orig_ctx,
|
|
|
freq_base,
|
|
|
freq_scale,
|
|
|
- 0.0f,
|
|
|
- 1.0f,
|
|
|
- 0.0f,
|
|
|
- 0.0f,
|
|
|
+ ext_factor,
|
|
|
+ attn_factor,
|
|
|
+ beta_fast,
|
|
|
+ beta_slow,
|
|
|
xpos_base,
|
|
|
xpos_down,
|
|
|
false),
|