@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
         float                 freq_scale,
-        int                   n_ctx,
         bool                  inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int                   n_dims,
         int                   mode,
         int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int                   n_past,
         int                   n_dims,
         int                   mode,
+        int                   n_ctx,
         float                 freq_base,
-        float                 freq_scale,
-        int                   n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+        float                 freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
 }
 
 // ggml_rope_back
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor  * a,
         int                   n_past,
         int                   n_dims,
-        int                   mode) {
+        int                   mode,
+        int                   n_ctx) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
     ggml_set_name(b, "n_past, n_dims, mode");
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode = ((int32_t *) src1->data)[2];
+                    const int n_ctx = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode = ((int32_t *) src1->data)[2];