
ggml-cpu: templateify ggml_compute_forward_rope_f32 and _f16 (#16805)

* extract rotate_pairs logic from ggml_compute_forward_rope_f32

* templateify ggml_compute_forward_rope_f32 and _f16

* abort when rope type not supported, remove GLM from test-rope

* add imrope branch to switch

* add rope tests for perf

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update ggml/src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
duduta · 2 months ago
parent · commit 73460f6278
3 changed files with 76 additions and 268 deletions:

  1. ggml/src/ggml-cpu/ops.cpp (+54, -263)
  2. tests/test-backend-ops.cpp (+16, -0)
  3. tests/test-rope.cpp (+6, -5)
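
Every RoPE variant touched by this diff applies the same 2x2 rotation to a pair of channels; the variants differ only in which pair gets selected, which is exactly what the new rotate_pairs template factors out. As the comment retained in the diff notes, the backward pass reuses the forward code with the sign of sin flipped (sin_sign), since cos and sin build a rotation matrix whose inverse is its transpose. In LaTeX form:

\begin{pmatrix} y_0 \\ y_1 \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x_0 \\ x_1 \end{pmatrix},
\qquad
R(\theta)^{-1} = R(\theta)^{\mathsf{T}} = R(-\theta)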

ggml/src/ggml-cpu/ops.cpp (+54, -263)

@@ -5503,194 +5503,28 @@ static void ggml_mrope_cache_init(
     }
 }
 
-static void ggml_compute_forward_rope_f32(
-        const ggml_compute_params * params,
-        ggml_tensor * dst,
-        const bool forward) {
-
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
-
-    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
-
-    //const int n_past     = ((int32_t *) dst->op_params)[0];
-    const int n_dims     = ((int32_t *) dst->op_params)[1];
-    const int mode       = ((int32_t *) dst->op_params)[2];
-    //const int n_ctx      = ((int32_t *) dst->op_params)[3];
-    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
-
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
-
-    GGML_TENSOR_UNARY_OP_LOCALS
-
-    //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
-    //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
-
-    GGML_ASSERT(nb00 == sizeof(float));
-
-    const int ith = params->ith;
-    const int nth = params->nth;
-
-    const int nr = ggml_nrows(dst);
-
-    GGML_ASSERT(n_dims <= ne0);
-    GGML_ASSERT(n_dims % 2 == 0);
-
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
-
-    // row index used to determine which thread to use
-    int ir = 0;
-
-    const float theta_scale = powf(freq_base, -2.0f/n_dims);
-
-    float corr_dims[2];
-    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
-
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, multimodal rotary position embedding
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
-
-    if (is_mrope) {
-        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
-    }
-
-    if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
-    }
+template<typename T>
+static void rotate_pairs(const int64_t n, const int64_t n_offset, const float * cache, const T * src_data, T * dst_data, const int scale = 2) {
+  for (int64_t i0 = 0; i0 < n; i0 += 2) {
+    const int64_t ic = i0/scale; // hack for GGML_ROPE_TYPE_NORMAL, where we need ic = i0; for all other cases, ic = i0/2
 
-    const float * freq_factors = NULL;
-    if (src2 != NULL) {
-        GGML_ASSERT(src2->type == GGML_TYPE_F32);
-        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
-        freq_factors = (const float *) src2->data;
-    }
+    const float cos_theta = cache[i0 + 0];
+    const float sin_theta = cache[i0 + 1];
 
-    // backward process uses inverse rotation by cos and sin.
-    // cos and sin build a rotation matrix, where the inverse is the transpose.
-    // this essentially just switches the sign of sin.
-    const float sin_sign = forward ? 1.0f : -1.0f;
+    const T * const src = src_data + ic;
+    T * dst             = dst_data + ic;
 
-    const int32_t * pos = (const int32_t *) src1->data;
+    const float x0 = type_conversion_table<T>::to_f32(src[0]);
+    const float x1 = type_conversion_table<T>::to_f32(src[n_offset]);
 
-    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
-        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
-
-            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
-                const int64_t p = pos[i2];
-                ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-            else {
-                const int64_t p_t = pos[i2];
-                const int64_t p_h = pos[i2 + ne2];
-                const int64_t p_w = pos[i2 + ne2 * 2];
-                const int64_t p_e = pos[i2 + ne2 * 3];
-                ggml_mrope_cache_init(
-                    p_t, p_h, p_w, p_e, sections, is_imrope, is_vision,
-                    freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
-            }
-
-            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
-                if (ir++ < ir0) continue;
-                if (ir   > ir1) break;
-
-                if (is_neox || is_mrope) {
-                    if (is_vision){
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims];
-
-                            dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = src[0];
-                            const float x1 = src[n_dims/2];
-
-                            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-                            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[1];
-
-                        dst_data[0] = x0*cos_theta - x1*sin_theta;
-                        dst_data[1] = x0*sin_theta + x1*cos_theta;
-                    }
-                }
-
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = src[0];
-                        const float x1 = src[n_dims];
-
-                        dst_data[0]      = x0*cos_theta - x1*sin_theta;
-                        dst_data[n_dims] = x0*sin_theta + x1*cos_theta;
-                    }
-                } else {
-                    // fill the remain channels with data from src tensor
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        float * dst_data  = (float *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        dst_data[0] = src[0];
-                        dst_data[1] = src[1];
-                    }
-                }
-            }
-        }
-    }
+    dst[0]        = type_conversion_table<T>::from_f32(x0*cos_theta - x1*sin_theta);
+    dst[n_offset] = type_conversion_table<T>::from_f32(x0*sin_theta + x1*cos_theta);
+  }
 }
 
-// TODO: deduplicate f16/f32 code
-static void ggml_compute_forward_rope_f16(
+template<typename T> //float or ggml_fp16_t
+static void ggml_compute_forward_rope_flt(
         const ggml_compute_params * params,
         ggml_tensor * dst,
         const bool forward) {
@@ -5699,6 +5533,9 @@ static void ggml_compute_forward_rope_f16(
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
 
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+
     float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
     int sections[4];
 
@@ -5707,6 +5544,7 @@ static void ggml_compute_forward_rope_f16(
     const int mode       = ((int32_t *) dst->op_params)[2];
     //const int n_ctx      = ((int32_t *) dst->op_params)[3];
     const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
+
     memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
     memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
     memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
@@ -5715,13 +5553,13 @@ static void ggml_compute_forward_rope_f16(
     memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
     memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
-
     GGML_TENSOR_UNARY_OP_LOCALS
 
     //printf("ne0: %d, ne1: %d, ne2: %d, ne3: %d\n", ne0, ne1, ne2, ne3);
     //printf("n_past = %d, ne2 = %d\n", n_past, ne2);
 
-    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb0 == nb00);
+    GGML_ASSERT(nb0 == sizeof(T));
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -5746,12 +5584,11 @@ static void ggml_compute_forward_rope_f16(
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
     const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
 
-    if (is_mrope) {
+    if (mrope_used) {
         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }
 
@@ -5773,11 +5610,11 @@ static void ggml_compute_forward_rope_f16(
 
     const int32_t * pos = (const int32_t *) src1->data;
 
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
 
             float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
-            if (!is_mrope) {
+            if (!mrope_used) {
                 const int64_t p = pos[i2];
                 ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
@@ -5791,86 +5628,40 @@ static void ggml_compute_forward_rope_f16(
                     freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale);
             }
 
-            for (int64_t i1 = 0; i1 < ne1; i1++) {
+            for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads
                 if (ir++ < ir0) continue;
                 if (ir   > ir1) break;
 
-                if (is_neox || is_mrope) {
-                    if (is_vision) {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                            dst_data[0]      = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    } else {
-                        for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                            const int64_t ic = i0/2;
-
-                            const float cos_theta = cache[i0 + 0];
-                            const float sin_theta = cache[i0 + 1];
-
-                            const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                            ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                            const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                            const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims/2]);
-
-                            dst_data[0]        = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                            dst_data[n_dims/2] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                        }
-                    }
-                } else {
-                    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) {
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                              ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
-
-                        const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_CPU_FP16_TO_FP32(src[1]);
-
-                        dst_data[0] = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[1] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
+                T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
+
+                switch (mode) {
+                    case GGML_ROPE_TYPE_NORMAL:
+                        rotate_pairs<T>(n_dims, 1, cache, src, dst_data, 1);
+                        break;
+                    case GGML_ROPE_TYPE_NEOX:
+                    case GGML_ROPE_TYPE_MROPE:
+                    case GGML_ROPE_TYPE_IMROPE:
+                        rotate_pairs<T>(n_dims, n_dims/2, cache, src, dst_data);
+                        break;
+                    case GGML_ROPE_TYPE_VISION:
+                        rotate_pairs<T>(ne0, n_dims, cache, src, dst_data);
+                        break;
+                    default:
+                        GGML_ABORT("rope type not supported");
                 }
 
-                if (is_vision) {
-                    for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const int64_t ic = i0/2;
-
-                        const float cos_theta = cache[i0 + 0];
-                        const float sin_theta = cache[i0 + 1];
-
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
-
-                        const float x0 = GGML_CPU_FP16_TO_FP32(src[0]);
-                        const float x1 = GGML_CPU_FP16_TO_FP32(src[n_dims]);
-
-                        dst_data[0]      = GGML_CPU_FP32_TO_FP16(x0*cos_theta - x1*sin_theta);
-                        dst_data[n_dims] = GGML_CPU_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
-                    }
-                } else {
+                if (!is_vision) {
+                    // fill the remain channels with data from src tensor
                     for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
-                        const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-                        ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+                        const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                        T * dst_data  = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
 
                         dst_data[0] = src[0];
                         dst_data[1] = src[1];
                     }
                 }
-            }
+            } //attn-heads
         }
     }
 }
@@ -5884,11 +5675,11 @@ void ggml_compute_forward_rope(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, dst, true);
+                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, true);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, dst, true);
+                ggml_compute_forward_rope_flt<float>(params, dst, true);
             } break;
         default:
             {
@@ -5908,11 +5699,11 @@ void ggml_compute_forward_rope_back(
     switch (src0->type) {
         case GGML_TYPE_F16:
             {
-                ggml_compute_forward_rope_f16(params, dst, false);
+                ggml_compute_forward_rope_flt<ggml_fp16_t>(params, dst, false);
             } break;
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_rope_f32(params, dst, false);
+                ggml_compute_forward_rope_flt<float>(params, dst, false);
             } break;
         default:
             {

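The templated code above relies on a type_conversion_table<T> helper whose definition is outside this diff. Judging from the GGML_CPU_FP16_TO_FP32 / GGML_CPU_FP32_TO_FP16 macros in the removed f16 path, its specializations presumably look like this minimal sketch (an assumption for illustration, not part of the commit):

// Hypothetical sketch of the conversion helper used by rotate_pairs<T>;
// the real definition lives elsewhere in ops.cpp and is not shown in this diff.
template<typename T> struct type_conversion_table;

template<> struct type_conversion_table<float> {
    // f32 passes through unchanged
    static float to_f32(float x)   { return x; }
    static float from_f32(float x) { return x; }
};

template<> struct type_conversion_table<ggml_fp16_t> {
    // f16 round-trips through f32, as the old hand-written f16 loops did
    static float to_f32(ggml_fp16_t x)   { return GGML_CPU_FP16_TO_FP32(x); }
    static ggml_fp16_t from_f32(float x) { return GGML_CPU_FP32_TO_FP16(x); }
};

With something like this in place, rotate_pairs<float> compiles down to plain loads and stores, while rotate_pairs<ggml_fp16_t> converts through f32 around the rotation, matching what the two duplicated loops used to do by hand.
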
tests/test-backend-ops.cpp (+16, -0)

@@ -7603,6 +7603,22 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
         test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
     }
 
+    for (bool fw : {true, false}) { // fw == forward
+        for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+            for (bool ff : {false, true}) { // freq_factors
+                for (float v : { 0, 1 }) {
+                    test_cases.emplace_back(new test_rope(type, {128,  32, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 7B
+                    test_cases.emplace_back(new test_rope(type, {128,  64, 512, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // llama 65B
+                    test_cases.emplace_back(new test_rope(type, { 80,  32, 512, 1},  20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm)
+                    test_cases.emplace_back(new test_rope(type, { 64,   8, 512, 1},  64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B)
+                    test_cases.emplace_back(new test_rope(type, {128,  12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                    test_cases.emplace_back(new test_rope(type, {128,  12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE,  512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+                    test_cases.emplace_back(new test_rope(type, { 80,  16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                }
+            }
+        }
+    }
+
     std::vector<std::array<int64_t, 4>> reduce_rows_cases = {
         { 8192, 1,    1, 1 },
         { 8192, 8192, 1, 1 },

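As a quick sanity check of the pair selection, here is a standalone, float-only clone of rotate_pairs driven with the same (n, n_offset, scale) arguments the new switch passes. The function name and harness values are illustrative only, not part of the commit:

// Standalone sketch: a float-only rotate_pairs clone exercised with the
// argument combinations used by the new mode switch in ops.cpp.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

static void rotate_pairs_f32(int64_t n, int64_t n_offset, const float * cache,
                             const float * src, float * dst, int scale = 2) {
    for (int64_t i0 = 0; i0 < n; i0 += 2) {
        const int64_t ic = i0/scale;          // ic = i0 for NORMAL (scale = 1)

        const float cos_theta = cache[i0 + 0];
        const float sin_theta = cache[i0 + 1];

        const float x0 = src[ic];
        const float x1 = src[ic + n_offset];

        dst[ic]            = x0*cos_theta - x1*sin_theta;
        dst[ic + n_offset] = x0*sin_theta + x1*cos_theta;
    }
}

int main() {
    const int64_t n_dims = 8;
    std::vector<float> cache(n_dims), src(n_dims), dst(n_dims);

    for (int64_t i0 = 0; i0 < n_dims; i0 += 2) { // cache holds interleaved cos/sin
        cache[i0 + 0] = cosf(0.1f*i0);
        cache[i0 + 1] = sinf(0.1f*i0);
    }
    for (int64_t i = 0; i < n_dims; ++i) src[i] = (float) i;

    // GGML_ROPE_TYPE_NORMAL: adjacent pairs (x[i0], x[i0+1])
    rotate_pairs_f32(n_dims, 1, cache.data(), src.data(), dst.data(), 1);

    // GGML_ROPE_TYPE_NEOX/MROPE/IMROPE: split halves (x[ic], x[ic + n_dims/2])
    rotate_pairs_f32(n_dims, n_dims/2, cache.data(), src.data(), dst.data());

    printf("dst[0] = %f, dst[%d] = %f\n", dst[0], (int)(n_dims/2), dst[n_dims/2]);
    return 0;
}

GGML_ROPE_TYPE_VISION follows the same split-half pattern with n = ne0 and an offset of n_dims (the vision path requires n_dims == ne0/2), so all five supported modes reduce to one loop body.
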
tests/test-rope.cpp (+6, -5)

@@ -138,7 +138,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
     struct ggml_tensor * x;
 
     // rope f32
-    for (int m = 0; m < 6; ++m) {
+    for (int m = 0; m < 5; ++m) {
         const int ndims = 4;
 
         const int64_t n_rot = 128;
@@ -153,7 +153,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
         x = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f);
         int mode = -1;
 
-        if (m < 3) {
+        if (m < 2) {
             struct ggml_tensor * p0 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
             struct ggml_tensor * p1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
             struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2]);
@@ -163,8 +163,8 @@ int main(int /*argc*/, const char ** /*argv*/) {
                 ((int32_t *) p1->data)[i] = n_past_2 - n_past_0;
                 ((int32_t *) p2->data)[i] = n_past_2 + i;
             }
-            // test mode 0, 2, 4 (standard, GPT-NeoX, GLM)
-            mode = m == 0 ? 0 : m == 1 ? 2 : 4;
+            // test mode 0, 2  (standard, GPT-NeoX)
+            mode = m == 0 ? GGML_ROPE_TYPE_NORMAL : GGML_ROPE_TYPE_NEOX;
 
             // 100, 101, 102, ..., 172
             r0 = ggml_rope(ctx0, x,  p0, n_rot, mode);
@@ -180,7 +180,8 @@ int main(int /*argc*/, const char ** /*argv*/) {
             struct ggml_tensor * p2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne[2] * 4);
 
             int sections[4] = {16, 24, 24, 0};
-            mode = (m == 3) ? GGML_ROPE_TYPE_MROPE : (m == 4) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;
+
+            mode = (m == 2) ? GGML_ROPE_TYPE_MROPE : (m == 3) ? GGML_ROPE_TYPE_VISION : GGML_ROPE_TYPE_IMROPE;
 
             for (int i = 0; i < ne[2]; ++i) {
                 for (int j = 0; j < 4; ++j) {