|
@@ -5503,194 +5503,28 @@ static void ggml_mrope_cache_init(
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-static void ggml_compute_forward_rope_f32(
|
|
|
|
|
- const ggml_compute_params * params,
|
|
|
|
|
- ggml_tensor * dst,
|
|
|
|
|
- const bool forward) {
|
|
|
|
|
-
|
|
|
|
|
- const ggml_tensor * src0 = dst->src[0];
|
|
|
|
|
- const ggml_tensor * src1 = dst->src[1];
|
|
|
|
|
- const ggml_tensor * src2 = dst->src[2];
|
|
|
|
|
-
|
|
|
|
|
- float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
|
|
|
- int sections[4];
|
|
|
|
|
-
|
|
|
|
|
- //const int n_past = ((int32_t *) dst->op_params)[0];
|
|
|
|
|
- const int n_dims = ((int32_t *) dst->op_params)[1];
|
|
|
|
|
- const int mode = ((int32_t *) dst->op_params)[2];
|
|
|
|
|
- //const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
|
|
|
- const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
|
|
|
-
|
|
|
|
|
- memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
|
|
|
- memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
|
|
|
- memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
|
|
|
- memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
|
|
|
|
- memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
|
|
|
|
- memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
|
|
|
- memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
|
|
|
-
|
|
|
|
|
- GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
|
|
-
|
|
|
|
|
- //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
|
|
|
- //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
|
|
|
-
|
|
|
|
|
- GGML_ASSERT(nb00 == sizeof(float));
|
|
|
|
|
-
|
|
|
|
|
- const int ith = params->ith;
|
|
|
|
|
- const int nth = params->nth;
|
|
|
|
|
-
|
|
|
|
|
- const int nr = ggml_nrows(dst);
|
|
|
|
|
-
|
|
|
|
|
- GGML_ASSERT(n_dims <= ne0);
|
|
|
|
|
- GGML_ASSERT(n_dims % 2 == 0);
|
|
|
|
|
-
|
|
|
|
|
- // rows per thread
|
|
|
|
|
- const int dr = (nr + nth - 1)/nth;
|
|
|
|
|
|
|
|
|
|
- // row range for this thread
|
|
|
|
|
- const int ir0 = dr*ith;
|
|
|
|
|
- const int ir1 = MIN(ir0 + dr, nr);
|
|
|
|
|
-
|
|
|
|
|
- // row index used to determine which thread to use
|
|
|
|
|
- int ir = 0;
|
|
|
|
|
-
|
|
|
|
|
- const float theta_scale = powf(freq_base, -2.0f/n_dims);
|
|
|
|
|
-
|
|
|
|
|
- float corr_dims[2];
|
|
|
|
|
- ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
|
|
|
-
|
|
|
|
|
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
|
|
|
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, multimodal rotary position embedding
|
|
|
|
|
- const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
|
|
|
|
- const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
|
|
|
-
|
|
|
|
|
- if (is_mrope) {
|
|
|
|
|
- GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (is_vision) {
|
|
|
|
|
- GGML_ASSERT(n_dims == ne0/2);
|
|
|
|
|
- }
|
|
|
|
|
|
|
+template<typename T>
|
|
|
|
|
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
|
|
|
|
|
+ for (int64_t i0 = 0; i0 < n; i0 += 2) {
|
|
|
|
|
+ const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
|
|
|
|
|
|
|
|
- const float * freq_factors = NULL;
|
|
|
|
|
- if (src2 != NULL) {
|
|
|
|
|
- GGML_ASSERT(src2->type == GGML_TYPE_F32);
|
|
|
|
|
- GGML_ASSERT(src2->ne[0] >= n_dims / 2);
|
|
|
|
|
- freq_factors = (const float *) src2->data;
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ const float cos_theta = cache[i0 + 0];
|
|
|
|
|
+ const float sin_theta = cache[i0 + 1];
|
|
|
|
|
|
|
|
- // backward process uses inverse rotation by cos and sin.
|
|
|
|
|
- // cos and sin build a rotation matrix, where the inverse is the transpose.
|
|
|
|
|
- // this essentially just switches the sign of sin.
|
|
|
|
|
- const float sin_sign = forward ? 1.0f : -1.0f;
|
|
|
|
|
|
|
+ const T * const src = src_data + ic;
|
|
|
|
|
+ T * dst = dst_data + ic;
|
|
|
|
|
|
|
|
- const int32_t * pos = (const int32_t *) src1->data;
|
|
|
|
|
|
|
+ const float x0 = type_conversion_table<T>::to_f32(src[0]);
|
|
|
|
|
+ const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
|
|
|
|
|
|
|
|
- for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
|
|
|
- for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
|
|
|
-
|
|
|
|
|
- float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
|
|
|
- if (!is_mrope) {
|
|
|
|
|
- const int64_t p = pos[i2];
|
|
|
|
|
- ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
|
|
- }
|
|
|
|
|
- else {
|
|
|
|
|
- const int64_t p_t = pos[i2];
|
|
|
|
|
- const int64_t p_h = pos[i2 + ne2];
|
|
|
|
|
- const int64_t p_w = pos[i2 + ne2 * 2];
|
|
|
|
|
- const int64_t p_e = pos[i2 + ne2 * 3];
|
|
|
|
|
- ggml_mrope_cache_init(
|
|
|
|
|
- p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
|
|
|
|
|
- freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
|
|
|
- if (ir++ < ir0) continue;
|
|
|
|
|
- if (ir > ir1) break;
|
|
|
|
|
-
|
|
|
|
|
- if (is_neox || is_mrope) {
|
|
|
|
|
- if (is_vision){
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = src[0];
|
|
|
|
|
- const float x1 = src[n_dims];
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = src[0];
|
|
|
|
|
- const float x1 = src[n_dims/2];
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
|
|
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = src[0];
|
|
|
|
|
- const float x1 = src[1];
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst_data[1] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if (is_vision) {
|
|
|
|
|
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = src[0];
|
|
|
|
|
- const float x1 = src[n_dims];
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = x0*cos_theta - x1*sin_theta;
|
|
|
|
|
- dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- // fill the remain channels with data from src tensor
|
|
|
|
|
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
|
|
|
- const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
|
|
- float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = src[0];
|
|
|
|
|
- dst_data[1] = src[1];
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ dst[0] = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
|
|
|
|
|
+ dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
|
|
|
|
|
+ }
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// TODO: deduplicate f16/f32 code
|
|
|
|
|
-static void ggml_compute_forward_rope_f16(
|
|
|
|
|
|
|
+template<typename T> //float or ggml_fp16_t
|
|
|
|
|
+static void ggml_compute_forward_rope_flt(
|
|
|
const ggml_compute_params * params,
|
|
const ggml_compute_params * params,
|
|
|
ggml_tensor * dst,
|
|
ggml_tensor * dst,
|
|
|
const bool forward) {
|
|
const bool forward) {
|
|
@@ -5699,6 +5533,9 @@ static void ggml_compute_forward_rope_f16(
|
|
|
const ggml_tensor * src1 = dst->src[1];
|
|
const ggml_tensor * src1 = dst->src[1];
|
|
|
const ggml_tensor * src2 = dst->src[2];
|
|
const ggml_tensor * src2 = dst->src[2];
|
|
|
|
|
|
|
|
|
|
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
|
|
|
|
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
|
|
|
|
+
|
|
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
|
|
|
int sections[4];
|
|
int sections[4];
|
|
|
|
|
|
|
@@ -5707,6 +5544,7 @@ static void ggml_compute_forward_rope_f16(
|
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
const int mode = ((int32_t *) dst->op_params)[2];
|
|
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
//const int n_ctx = ((int32_t *) dst->op_params)[3];
|
|
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
|
|
|
|
+
|
|
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
|
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
|
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
|
@@ -5715,13 +5553,13 @@ static void ggml_compute_forward_rope_f16(
|
|
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
|
|
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int)*4);
|
|
|
|
|
|
|
|
-
|
|
|
|
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
GGML_TENSOR_UNARY_OP_LOCALS
|
|
|
|
|
|
|
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
//printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
|
|
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
//printf("n_past = %d, ne2 = %d\n", n_past, ne2);
|
|
|
|
|
|
|
|
- GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
|
|
|
|
|
|
+ GGML_ASSERT(nb0 == nb00);
|
|
|
|
|
+ GGML_ASSERT(nb0 == sizeof(T));
|
|
|
|
|
|
|
|
const int ith = params->ith;
|
|
const int ith = params->ith;
|
|
|
const int nth = params->nth;
|
|
const int nth = params->nth;
|
|
@@ -5746,12 +5584,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
float corr_dims[2];
|
|
float corr_dims[2];
|
|
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
|
|
|
|
|
|
|
|
- const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
|
|
|
|
- const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
|
|
|
|
- const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
|
|
|
|
|
|
|
+ const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
|
|
|
|
|
+ const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
|
|
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
|
|
|
|
|
|
|
- if (is_mrope) {
|
|
|
|
|
|
|
+ if (mrope_used) {
|
|
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
@@ -5773,11 +5610,11 @@ static void ggml_compute_forward_rope_f16(
|
|
|
|
|
|
|
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
const int32_t * pos = (const int32_t *) src1->data;
|
|
|
|
|
|
|
|
- for (int64_t i3 = 0; i3 < ne3; i3++) {
|
|
|
|
|
- for (int64_t i2 = 0; i2 < ne2; i2++) {
|
|
|
|
|
|
|
+ for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
|
|
|
|
|
+ for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
|
|
|
|
|
|
|
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
|
|
|
- if (!is_mrope) {
|
|
|
|
|
|
|
+ if (!mrope_used) {
|
|
|
const int64_t p = pos[i2];
|
|
const int64_t p = pos[i2];
|
|
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
}
|
|
}
|
|
@@ -5791,86 +5628,40 @@ static void ggml_compute_forward_rope_f16(
|
|
|
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- for (int64_t i1 = 0; i1 < ne1; i1++) {
|
|
|
|
|
|
|
+ for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
|
|
|
if (ir++ < ir0) continue;
|
|
if (ir++ < ir0) continue;
|
|
|
if (ir > ir1) break;
|
|
if (ir > ir1) break;
|
|
|
|
|
|
|
|
- if (is_neox || is_mrope) {
|
|
|
|
|
- if (is_vision) {
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
|
|
|
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
|
|
|
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
|
|
|
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
|
|
|
- dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
- for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
|
|
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
|
|
|
- const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
|
|
|
- dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
|
|
|
- }
|
|
|
|
|
|
|
+ T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
|
|
|
|
|
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1);
|
|
|
|
|
+
|
|
|
|
|
+ switch (mode) {
|
|
|
|
|
+ case GGML_ROPE_TYPE_NORMAL:
|
|
|
|
|
+ rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
|
|
|
|
|
+ break;
|
|
|
|
|
+ case GGML_ROPE_TYPE_NEOX:
|
|
|
|
|
+ case GGML_ROPE_TYPE_MROPE:
|
|
|
|
|
+ case GGML_ROPE_TYPE_IMROPE:
|
|
|
|
|
+ rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
|
|
|
|
|
+ break;
|
|
|
|
|
+ case GGML_ROPE_TYPE_VISION:
|
|
|
|
|
+ rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
|
|
|
|
|
+ break;
|
|
|
|
|
+ default:
|
|
|
|
|
+ GGML_ABORT("rope type not supported");
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- if (is_vision) {
|
|
|
|
|
- for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
|
|
|
- const int64_t ic = i0/2;
|
|
|
|
|
-
|
|
|
|
|
- const float cos_theta = cache[i0 + 0];
|
|
|
|
|
- const float sin_theta = cache[i0 + 1];
|
|
|
|
|
-
|
|
|
|
|
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
|
|
|
|
|
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
|
|
|
|
|
-
|
|
|
|
|
- const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
|
|
|
|
|
- const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
|
|
|
|
|
-
|
|
|
|
|
- dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
|
|
|
|
|
- dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
|
|
|
|
|
- }
|
|
|
|
|
- } else {
|
|
|
|
|
|
|
+ if (!is_vision) {
|
|
|
|
|
+ // fill the remain channels with data from src tensor
|
|
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
|
|
|
- const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
|
|
- ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
|
|
+ const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
|
|
|
|
|
+ T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
|
|
|
|
|
|
|
|
dst_data[0] = src[0];
|
|
dst_data[0] = src[0];
|
|
|
dst_data[1] = src[1];
|
|
dst_data[1] = src[1];
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
- }
|
|
|
|
|
|
|
+ } //attn-heads
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
@@ -5884,11 +5675,11 @@ void ggml_compute_forward_rope(
|
|
|
switch (src0->type) {
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F16:
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
{
|
|
|
- ggml_compute_forward_rope_f16(params, dst, true);
|
|
|
|
|
|
|
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
|
|
|
} break;
|
|
} break;
|
|
|
case GGML_TYPE_F32:
|
|
case GGML_TYPE_F32:
|
|
|
{
|
|
{
|
|
|
- ggml_compute_forward_rope_f32(params, dst, true);
|
|
|
|
|
|
|
+ ggml_compute_forward_rope_flt<float>(params, dst, true);
|
|
|
} break;
|
|
} break;
|
|
|
default:
|
|
default:
|
|
|
{
|
|
{
|
|
@@ -5908,11 +5699,11 @@ void ggml_compute_forward_rope_back(
|
|
|
switch (src0->type) {
|
|
switch (src0->type) {
|
|
|
case GGML_TYPE_F16:
|
|
case GGML_TYPE_F16:
|
|
|
{
|
|
{
|
|
|
- ggml_compute_forward_rope_f16(params, dst, false);
|
|
|
|
|
|
|
+ ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
|
|
|
} break;
|
|
} break;
|
|
|
case GGML_TYPE_F32:
|
|
case GGML_TYPE_F32:
|
|
|
{
|
|
{
|
|
|
- ggml_compute_forward_rope_f32(params, dst, false);
|
|
|
|
|
|
|
+ ggml_compute_forward_rope_flt<float>(params, dst, false);
|
|
|
} break;
|
|
} break;
|
|
|
default:
|
|
default:
|
|
|
{
|
|
{
|