Преглед изворни кода

CANN: Add MROPE and IMROPE support (#17401)

* CANN: ROPE supports both MROPE and IMROPE.

1. Optimize the caching logic of rope_cache_init.
2. Add support for mRoPE and i-mRoPE.

Note that on Ascend 910B devices, it is necessary to disable FA
in CLIP and disable NZ-format conversion. These two issues are
still under investigation.

* Resolve review comments
hipudding пре 1 месец
родитељ
комит
eeb5605de2
3 измењених фајлова са 411 додато и 170 уклоњено
  1. 335 149
      ggml/src/ggml-cann/aclnn_ops.cpp
  2. 76 14
      ggml/src/ggml-cann/common.h
  3. 0 7
      ggml/src/ggml-cann/ggml-cann.cpp

+ 335 - 149
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -2207,78 +2207,120 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context & ctx,
 }
 
 /**
- * @brief Initializes and caches sine/cosine positional encoding values
- *        (used in RoPE, Rotary Position Embedding) for attention layers.
- *
- * This function computes and caches the sin/cos values of
- * θ = position * theta_scale for RoPE encoding. The cache is shared
- * across attention layers, and only the first attention layer will
- * trigger initialization. The cache includes repeated sin/cos values
- * with different repeat methods depending on the @param is_neox flag.
- *
- * Steps performed by this function:
- *   1. Identify whether the target tensor belongs to Q/K in attention
- *      and restrict computation to the first layer only.
- *   2. Initialize the theta scale array (arange → power → freq scaling).
- *   3. Allocate sin/cos caches if the max prompt length increases.
- *   4. Compute θ = position * theta_scale.
- *   5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
- *   6. Expand sin/cos values by repeat or repeat_interleave depending
- *      on whether @param is_neox is enabled.
- *
- * @param ctx                The CANN backend context, holding memory pool,
- *                           stream, and persistent buffers for rope init/cache.
- * @param dst                The destination ggml_tensor whose computation
- *                           depends on the RoPE values (usually Qcur/Kcur).
- * @param theta_scale        Scalar exponent base for computing theta scale values.
- * @param freq_scale         Frequency scaling factor, applied to theta scale.
- * @param attn_factor        Attention scaling factor, applied to sin/cos.
- * @param is_neox            Whether to use Neox-style repeat strategy
- *                           (dim expansion vs repeat_interleave).
+ * @brief Initializes and caches all intermediate tensors required for RoPE
+ *        (Rotary Position Embedding), including support for Yarn, mRoPE,
+ *        i-mRoPE, Neox repeat strategy, independent sectors, frequency factors,
+ *        and multi-section rotary groups.
+ *
+ * This function computes and caches the per-dimension θ coefficients used for
+ * Q/K rotary embedding. The cache is shared across layers, and recomputed only
+ * when any dependent parameter changes.
+ *
+ * The function now supports:
+ *   - Yarn RoPE extrapolation (via @param corr_dims and @param ext_factor)
+ *   - Per-dimension independent sector exponent rules (indep_sects + sections[])
+ *   - Multi-section RoPE (mRoPE) index mapping (mrope_used + is_imrope)
+ *   - Frequency factor division (src2)
+ *   - Neox / normal repeat expansion modes
+ *
+ * @param ctx                CANN backend context, containing memory pool,
+ *                           cached buffers, and runtime stream.
+ * @param dst                Destination ggml_tensor whose computation
+ *                           depends on RoPE (typically Qcur or Kcur).
+ * @param corr_dims          [low, high] Yarn correction range.
+ * @param ext_factor         Yarn extrapolation strength. 0 = disabled.
+ * @param theta_scale        Base multiplier for per-dimension θ exponent.
+ * @param freq_scale         Global frequency scaling factor.
+ * @param attn_factor        Optional scaling applied to sin/cos (if needed).
+ * @param is_neox            Whether to use Neox-style dimension interleave.
+ * @param sections           4-way sector sizes for independent-section RoPE
+ *                           and multi-section mRoPE (t/h/w/e).
+ * @param mrope_used         Whether to enable multi-section rotary embedding.
+ * @param is_imrope          Whether to apply interleaved mRoPE rules.
+ * @param indep_sects        Whether each dimension runs independent exponent
+ *                           resets based on @p sections.
  */
-static void aclnn_cache_init(ggml_backend_cann_context & ctx,
-                             ggml_tensor *               dst,
-                             float *                     corr_dims,
-                             float                       ext_factor,
-                             float                       theta_scale,
-                             float                       freq_scale,
-                             float                       attn_factor,
-                             bool                        is_neox) {
+static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
+                                  ggml_tensor *               dst,
+                                  float *                     corr_dims,
+                                  float                       ext_factor,
+                                  float                       theta_scale,
+                                  float                       freq_scale,
+                                  float                       attn_factor,
+                                  bool                        is_neox,
+                                  int                         sections[4],
+                                  bool                        mrope_used,
+                                  bool                        is_imrope,
+                                  bool                        indep_sects) {
     ggml_tensor * src0 = dst->src[0];  // input
     ggml_tensor * src1 = dst->src[1];  // position
     ggml_tensor * src2 = dst->src[2];  // freq_factors
 
-    if (src2 == nullptr && ctx.rope_cache.cached && ctx.rope_cache.ext_factor == ext_factor &&
-        ctx.rope_cache.theta_scale == theta_scale && ctx.rope_cache.freq_scale == freq_scale &&
-        ctx.rope_cache.attn_factor == attn_factor && ctx.rope_cache.is_neox == is_neox) {
+    int64_t theta_scale_length = src0->ne[0] / 2;
+    int64_t position_length    = dst->ne[2];
+
+    // TODO: check theta_scale_length and position_length.
+    if (src2 == nullptr && ctx.rope_cache.cached &&
+        ctx.rope_cache.equal(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor,
+                             is_neox, indep_sects, mrope_used, is_imrope, sections)) {
         // use cache.
         return;
     }
 
-    int64_t theta_scale_length = src0->ne[0] / 2;
-    int64_t theta_scale_ne[]   = { theta_scale_length, 1, 1, 1 };
-    size_t  theta_scale_nb[]   = { sizeof(float), sizeof(float), sizeof(float), theta_scale_length * sizeof(float) };
+    // Step0: calculate tensor shape.
+    int64_t theta_scale_ne[] = { theta_scale_length, 1, 1, 1 };
+    size_t  theta_scale_nb[] = { sizeof(float), theta_scale_length * sizeof(float), theta_scale_length * sizeof(float),
+                                 theta_scale_length * sizeof(float) };
 
     GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    int64_t position_length = src1->ne[0];
-    int64_t position_ne[]   = { 1, 1, position_length, 1 };
-    size_t  position_nb[]   = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
+    int64_t position_ne[] = { 1, 1, position_length, 1 };
+    size_t  position_nb[] = { sizeof(int32_t), sizeof(int32_t), sizeof(int32_t), sizeof(int32_t) * position_length };
 
-    int64_t theta_ne[] = { theta_scale_length, 1, position_length, 1 };
-    size_t  theta_nb[GGML_MAX_DIMS];
-    theta_nb[0] = sizeof(float);
+    int64_t cache_ne[] = { theta_scale_length, 1, position_length, 1 };
+    size_t  cache_nb[GGML_MAX_DIMS];
+    cache_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
+        cache_nb[i] = cache_nb[i - 1] * cache_ne[i - 1];
     }
 
-    // theta_scale arange, [0,1,...,ne00/2 - 1]
+    // Step1: Compute the coefficient of theta. During the cache_init process, aside from
+    // (1) multiplying by the position,
+    // (2) dividing by freq_factors,
+    // (3) computing the sine and cosine,
+    // the other parameters used in the computation generally do not change in most scenarios.
+    // Therefore, we can first compute this part of the result and then cache it.
+
+    // Step1.1: prepare theta_scale exponent. if this exponent updated, should update theta_scale_tensor.
     acl_tensor_ptr acl_theta_scale_tensor;
-    // cache theta scale
-    if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
-        // theta_scale and freq_scale should not change during the current token inference process,
-        // so we can directly use == here instead of comparing the absolute difference.
-        ctx.rope_cache.theta_scale != theta_scale || ctx.rope_cache.freq_scale != freq_scale) {
-        ctx.rope_cache.theta_scale_length = theta_scale_length;
+    bool           theta_scale_updated = false;
+    if (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.theta_scale != theta_scale ||
+        ctx.rope_cache.indep_sects != indep_sects) {
+        theta_scale_updated = true;
+        if (ctx.rope_cache.theta_scale_exp_host != nullptr) {
+            free(ctx.rope_cache.theta_scale_exp_host);
+        }
+        ctx.rope_cache.theta_scale_exp_host = (float *) malloc(theta_scale_length * sizeof(float));
+        GGML_ASSERT(ctx.rope_cache.theta_scale_exp_host != nullptr);
+        if (!indep_sects) {
+            ctx.rope_cache.theta_scale_exp_host[0] = 1;
+            for (int i = 1; i < theta_scale_length; i++) {
+                ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
+            }
+        } else {
+            int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+            int sec_w     = sections[1] + sections[0];
+            int sec_e     = sections[2] + sec_w;
+
+            ctx.rope_cache.theta_scale_exp_host[0] = 1;
+            for (int i = 1; i < theta_scale_length; i++) {
+                int sector = i % sect_dims;
+                if (sector == 0 || sector == sections[0] || sector == sec_w || sector == sec_e) {
+                    ctx.rope_cache.theta_scale_exp_host[i] = 1;
+                    continue;
+                }
+                ctx.rope_cache.theta_scale_exp_host[i] = ctx.rope_cache.theta_scale_exp_host[i - 1] * theta_scale;
+            }
+        }
 
         if (ctx.rope_cache.theta_scale_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@@ -2286,74 +2328,138 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
         ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                               ACL_MEM_MALLOC_HUGE_FIRST));
 
+        ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
+                                   ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
+                                   ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+
         acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
                                                          theta_scale_ne, theta_scale_nb, 1);
+    }
 
-        float start      = 0;
-        float step       = 1;
-        float stop       = theta_scale_length;
-        float n_elements = theta_scale_length;
-        aclnn_arange(ctx, acl_theta_scale_tensor.get(), start, stop, step, n_elements);
-
-        ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
-        acl_tensor_ptr       acl_yarn_ramp_tensor;
-        if (ext_factor != 0) {
-            // -rope_yarn_ramp
-            // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
-            // return MIN(1, MAX(0, y)) - 1;
-            yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
-            void * yarn_ramp_buffer = yarn_ramp_allocator.get();
-            acl_yarn_ramp_tensor =
-                ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
-            float          zero_value = 0, one_value = 1;
-            float          denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
-            acl_scalar_ptr low              = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
-            acl_scalar_ptr zero             = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
-            acl_scalar_ptr one              = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
-            acl_scalar_ptr denom_safe       = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
-            acl_scalar_ptr ext_factor_sc    = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
-
-            GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor.get(), low.get(), one.get(),
-                                    acl_yarn_ramp_tensor.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
-
-            // theta_interp = freq_scale * theta_extrap;
-            // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-            // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
-            // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
-            // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
-            //
-            // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
-            // cache freq_scale + (freq_scale - 1) * ramp_mix
-            float          freq_scale_1    = freq_scale - 1;
-            acl_scalar_ptr freq_scale_sc   = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
-            acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
-            GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
-        }
-
-        // power
-        acl_scalar_ptr acl_theta_scale = ggml_cann_create_scalar(&theta_scale, aclDataType::ACL_FLOAT);
-        GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale.get(), acl_theta_scale_tensor.get(),
-                                acl_theta_scale_tensor.get());
-
-        if (ext_factor != 0) {
+    // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
+    bool                 yarn_ramp_tensor_updated = false;
+    ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
+    acl_tensor_ptr       acl_yarn_ramp_tensor;
+    if (ext_factor != 0 &&
+        // TODO: check more parameter.
+        (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
+        yarn_ramp_tensor_updated = true;
+
+        // -rope_yarn_ramp
+        // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
+        // return MIN(1, MAX(0, y)) - 1;
+        yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
+        void * yarn_ramp_buffer = yarn_ramp_allocator.get();
+        acl_yarn_ramp_tensor =
+            ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+        float          zero_value = 0, one_value = 1;
+        float          denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
+        acl_scalar_ptr low              = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
+        acl_scalar_ptr zero             = ggml_cann_create_scalar(&zero_value, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr one              = ggml_cann_create_scalar(&one_value, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr denom_safe       = ggml_cann_create_scalar(&denom_safe_value, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr ext_factor_sc    = ggml_cann_create_scalar(&ext_factor, aclDataType::ACL_FLOAT);
+
+        aclnn_arange(ctx, acl_yarn_ramp_tensor.get(), 0, theta_scale_length, 1, theta_scale_length);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), low.get(), one.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor.get(), denom_safe.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor.get(), zero.get(), zero.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor.get(), one.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor.get(), one.get(), one.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), ext_factor_sc.get());
+
+        // theta_interp = freq_scale * theta_extrap;
+        // theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        // theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
+        // theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
+        // theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
+        //
+        // we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
+        // cache freq_scale + (freq_scale - 1) * ramp_mix
+        float          freq_scale_1    = freq_scale - 1;
+        acl_scalar_ptr freq_scale_sc   = ggml_cann_create_scalar(&freq_scale, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr freq_scale_1_sc = ggml_cann_create_scalar(&freq_scale_1, aclDataType::ACL_FLOAT);
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
+    }
+
+    // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
+    if (ext_factor != 0) {
+        if (theta_scale_updated || yarn_ramp_tensor_updated) {
+            theta_scale_updated = true;
             aclnn_mul(ctx, acl_theta_scale_tensor.get(), acl_yarn_ramp_tensor.get());
-        } else if (freq_scale != 1) {
-            aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
         }
     } else {
-        // use cache
+        if (freq_scale != 1 && (ctx.rope_cache.freq_scale != freq_scale || theta_scale_updated)) {
+            theta_scale_updated = true;
+            aclnn_muls(ctx, acl_theta_scale_tensor.get(), freq_scale, nullptr, true);
+        }
+    }
+
+    // Nothing changed, use cache.
+    if (!theta_scale_updated) {
         acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
                                                          theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
     }
 
+    // Step 1.4: prepare select index if mrope
+    acl_tensor_ptr position_select_index_tensor;
+    if (mrope_used) {
+        if (ctx.rope_cache.sections[0] != sections[0] || ctx.rope_cache.sections[1] != sections[1] ||
+            ctx.rope_cache.sections[2] != sections[2] || ctx.rope_cache.sections[3] != sections[3] ||
+            ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.is_imrope != is_imrope) {
+            if (ctx.rope_cache.position_select_index_host != nullptr) {
+                free(ctx.rope_cache.position_select_index_host);
+            }
+            ctx.rope_cache.position_select_index_host = (int *) malloc(theta_scale_length * sizeof(int));
+            GGML_ASSERT(ctx.rope_cache.position_select_index_host != nullptr);
+            int sect_dims = sections[0] + sections[1] + sections[2] + sections[3];
+            int sec_w     = sections[1] + sections[0];
+            int sec_e     = sections[2] + sec_w;
+            // t,h,w,e
+            for (int i = 0; i < theta_scale_length; i++) {
+                int sector = i % sect_dims;
+
+                if (is_imrope) {  // qwen3vl apply interleaved mrope
+                    if (sector % 3 == 1 && sector < 3 * sections[1]) {
+                        ctx.rope_cache.position_select_index_host[i] = 1;
+                    } else if (sector % 3 == 2 && sector < 3 * sections[2]) {
+                        ctx.rope_cache.position_select_index_host[i] = 2;
+                    } else if (sector % 3 == 0 && sector < 3 * sections[0]) {
+                        ctx.rope_cache.position_select_index_host[i] = 0;
+                    } else {
+                        ctx.rope_cache.position_select_index_host[i] = 3;
+                    }
+                } else {
+                    if (sector >= sections[0] && sector < sec_w) {
+                        ctx.rope_cache.position_select_index_host[i] = 1;
+                    } else if (sector >= sec_w && sector < sec_e) {
+                        ctx.rope_cache.position_select_index_host[i] = 2;
+                    } else if (sector >= sec_e) {
+                        ctx.rope_cache.position_select_index_host[i] = 3;
+                    } else {
+                        ctx.rope_cache.position_select_index_host[i] = 0;
+                    }
+                }
+            }
+
+            if (ctx.rope_cache.position_select_index != nullptr) {
+                ACL_CHECK(aclrtFree(ctx.rope_cache.position_select_index));
+            }
+            ACL_CHECK(aclrtMalloc(&ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
+                                  ACL_MEM_MALLOC_HUGE_FIRST));
+
+            ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.position_select_index, theta_scale_length * sizeof(int),
+                                       ctx.rope_cache.position_select_index_host, theta_scale_length * sizeof(int),
+                                       ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
+        }
+
+        position_select_index_tensor = ggml_cann_create_tensor(ctx.rope_cache.position_select_index, ACL_INT32,
+                                                               sizeof(int), theta_scale_ne, theta_scale_nb, 1);
+    }
+
+    // Step2: divide by freq_factors
     ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
-    // freq_factors
     if (src2) {
         freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
         void *         freq_fac_res_ptr = freq_fac_res_allocator.get();
@@ -2366,6 +2472,85 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
         std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
     }
 
+    // Step3: prepare position_tensor
+    acl_tensor_ptr       acl_position_tensor;
+    ggml_cann_pool_alloc mrope_position_acllocator(ctx.pool());
+    if (mrope_used) {
+        // Step3.1: select current position;
+        // position :
+        // pos1: [[0, 1 ,2 ,3 ],
+        // pos2:  [4, 5 ,6 ,7 ],
+        // pos3:  [8, 9 ,10,11],
+        // pos4:  [12,13,14,15] ]
+        //
+        // select index = [0, 1, 2, 2, 1, 0]
+        //
+        // selected_tensor:
+        // [[0, 1 ,2 ,3 ],
+        //  [4, 5 ,6 ,7 ],
+        //  [8, 9 ,10,11],
+        //  [8, 9 ,10,11],
+        //  [4, 5 ,6 ,7 ],
+        //  [0, 1 ,2 ,3 ]]
+        //
+        // transpose, from [seq_len:dims] to [dims:seq_len]
+        // [0, 4, 8 ,8 ,4, 0],
+        // [1, 5, 9, 9, 5, 1],
+        // [2, 6, 10,10,6 ,2],
+        // [3, 7, 11,11,7 3 ]]
+        //
+        // multipy by theta_scale_tensor
+        // [theta_scale^0, theta_scale^1, ..., theta_scale ^ n]
+
+        int64_t        mrope_position_ne[] = { position_length, 4 };
+        size_t         mrope_position_nb[] = { sizeof(int), position_length * sizeof(int) };
+        acl_tensor_ptr mrope_position =
+            ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
+                                    mrope_position_ne, mrope_position_nb, 2);
+
+        // selected position tensor's shape is a transpose of cache tensor.
+        int64_t selected_position_ne[] = { position_length, theta_scale_length };
+        size_t  selected_position_nb[] = { sizeof(float), position_length * sizeof(float) };
+        mrope_position_acllocator.alloc(theta_scale_length * position_length * sizeof(float));
+        void * mrope_position_buffer = mrope_position_acllocator.get();
+        acl_position_tensor =
+            ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
+                                    ggml_type_size(src1->type), selected_position_ne, selected_position_nb, 2);
+        GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, mrope_position.get(), 0, position_select_index_tensor.get(),
+                                acl_position_tensor.get());
+
+        // transpose
+        int64_t transposed_ne[] = { position_length, 1, theta_scale_length, 1 };
+        size_t  transposed_nb[GGML_MAX_DIMS];
+        transposed_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            transposed_nb[i] = transposed_nb[i - 1] * transposed_ne[i - 1];
+        }
+
+        std::swap(transposed_ne[0], transposed_ne[2]);
+        std::swap(transposed_nb[0], transposed_nb[2]);
+
+        acl_position_tensor =
+            ggml_cann_create_tensor(mrope_position_buffer, ggml_cann_type_mapping(src1->type),
+                                    ggml_type_size(src1->type), transposed_ne, transposed_nb, GGML_MAX_DIMS);
+
+    } else {
+        // auto bcast.
+        acl_position_tensor =
+            ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type),
+                                    position_ne, position_nb, GGML_MAX_DIMS);
+    }
+
+    // Step4: multiply by the position
+    int64_t              theta_length = theta_scale_length * position_length;
+    ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
+    void *               theta_buffer = theta_allocator.get();
+
+    acl_tensor_ptr acl_theta_tensor =
+        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS);
+    aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
+
+    // Step5: calculate sin cos.
     // init sin_repeat && cos_repeat, only to accelerate first layer on each device
     if (position_length > ctx.rope_cache.position_length) {
         ctx.rope_cache.position_length = position_length;
@@ -2382,44 +2567,30 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
             aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
     }
 
-    // position
-    acl_tensor_ptr acl_position_tensor =
-        ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), position_ne,
-                                position_nb, GGML_MAX_DIMS);
-
-    // power * position
-    int64_t              theta_length = theta_scale_length * position_length;
-    ggml_cann_pool_alloc theta_allocator(ctx.pool(), theta_length * sizeof(float));
-    void *               theta_buffer = theta_allocator.get();
-
-    acl_tensor_ptr acl_theta_tensor =
-        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS);
-    aclnn_mul(ctx, acl_position_tensor.get(), acl_theta_scale_tensor.get(), acl_theta_tensor.get());
-
     // sin/cos
     ggml_cann_pool_alloc sin_allocator(ctx.pool(), theta_length * sizeof(float));
     void *               sin_buffer = sin_allocator.get();
     acl_tensor_ptr       acl_sin_tensor =
-        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
     aclnn_sin(ctx, acl_theta_tensor.get(), acl_sin_tensor.get());
 
     ggml_cann_pool_alloc cos_allocator(ctx.pool(), theta_length * sizeof(float));
     void *               cos_buffer = cos_allocator.get();
     acl_tensor_ptr       acl_cos_tensor =
-        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
+        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float), cache_ne, cache_nb, GGML_MAX_DIMS, ACL_FORMAT_ND);
     aclnn_cos(ctx, acl_theta_tensor.get(), acl_cos_tensor.get());
 
     if (ext_factor != 0) {
         attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
     }
 
-    // attn_factor
+    // Step 5: multiply by attn_factor
     if (attn_factor != 1) {
         aclnn_muls(ctx, acl_sin_tensor.get(), attn_factor, nullptr, true);
         aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = { src0->ne[0], 1, src0->ne[2], 1 };
+    int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
     size_t  sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2430,8 +2601,9 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
     acl_tensor_ptr acl_cos_repeat_tensor = ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
                                                                    sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
 
-    // repeat
+    // Step 6: repeat
     if (is_neox) {
+        // [sinθ1, sinθ1, sinθ2, sinθ2, ..., sinθn, sinθn]
         int64_t repeatsArray[] = { 1, 1, 1, 2 };
         aclnn_repeat(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), repeatsArray);
         aclnn_repeat(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), repeatsArray);
@@ -2439,17 +2611,15 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
         int64_t num_repeats = 2;
         int64_t dim         = 3;
         int64_t output_size = theta_scale_length * num_repeats;
+        // [sinθ1, sinθ2, ..., sinθn, sinθ1, sinθ2, ..., sinθn]
         aclnn_repeat_interleave(ctx, acl_sin_tensor.get(), acl_sin_repeat_tensor.get(), dim, num_repeats, output_size);
         aclnn_repeat_interleave(ctx, acl_cos_tensor.get(), acl_cos_repeat_tensor.get(), dim, num_repeats, output_size);
     }
 
-    // Other layers use cache except first layer.
-    ctx.rope_cache.cached      = true;
-    ctx.rope_cache.ext_factor  = ext_factor;
-    ctx.rope_cache.theta_scale = theta_scale;
-    ctx.rope_cache.freq_scale  = freq_scale;
-    ctx.rope_cache.attn_factor = attn_factor;
-    ctx.rope_cache.is_neox     = is_neox;
+    // Update cached value.
+    ctx.rope_cache.cached = true;
+    ctx.rope_cache.set(theta_scale_length, position_length, ext_factor, theta_scale, freq_scale, attn_factor, is_neox,
+                       indep_sects, mrope_used, is_imrope, sections);
 }
 
 #ifdef __cplusplus
@@ -2475,6 +2645,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
     // param
     float     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    int sections[4];
     // const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
@@ -2483,12 +2654,13 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
-    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
-    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
-    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
-    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
+    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
+    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
+    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
 
     // TODO: n_dims <= ne0
     GGML_ASSERT(n_dims == ne0);
@@ -2499,10 +2671,25 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
+    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+    if (mrope_used) {
+        GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
+    }
+
+    if (is_vision) {
+        GGML_ASSERT(n_dims == ne0/2);
+    }
+
+    if (is_imrope || mrope_used) {
+        is_neox = true;
+    }
 
     // init ctx.rope_cos/rope_sin cache
-    aclnn_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox);
+    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
 
     int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
     size_t  sin_reshape_nb[GGML_MAX_DIMS];
@@ -2658,8 +2845,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     return;
 #endif
 
-    // ggml_mode = 0 --> aclnn_model = 1
-    int64_t acl_mode = mode == 0 ? 1 : mode;
+    int64_t acl_mode = is_neox ? 0 : 1;
 
     switch (src0->type) {
         case GGML_TYPE_F32:

+ 76 - 14
ggml/src/ggml-cann/common.h

@@ -300,30 +300,92 @@ struct ggml_cann_graph_lru_cache {
 
 struct ggml_cann_rope_cache {
     ~ggml_cann_rope_cache() {
-        if (theta_scale_cache != nullptr) {
+        if (theta_scale_cache) {
             ACL_CHECK(aclrtFree(theta_scale_cache));
         }
-        if (sin_cache != nullptr) {
+        if (sin_cache) {
             ACL_CHECK(aclrtFree(sin_cache));
         }
-        if (cos_cache != nullptr) {
+        if (cos_cache) {
             ACL_CHECK(aclrtFree(cos_cache));
         }
+        if (position_select_index) {
+            ACL_CHECK(aclrtFree(position_select_index));
+        }
+        if (theta_scale_exp_host) {
+            free(theta_scale_exp_host);
+        }
+        if(position_select_index_host) {
+            free(position_select_index_host);
+        }
+    }
+
+    bool equal(int64_t theta_scale_length,
+               int64_t position_length,
+               float   ext_factor,
+               float   theta_scale,
+               float   freq_scale,
+               float   attn_factor,
+               bool    is_neox,
+               bool    indep_sects,
+               bool    mrope_used,
+               bool    is_imrope,
+               int     sections[4]) {
+        return this->theta_scale_length == theta_scale_length && this->position_length == position_length &&
+               this->ext_factor == ext_factor && this->theta_scale == theta_scale && this->freq_scale == freq_scale &&
+               this->attn_factor == attn_factor && this->is_neox == is_neox && this->indep_sects == indep_sects &&
+               this->mrope_used == mrope_used && this->is_imrope == is_imrope && this->sections[0] == sections[0] &&
+               this->sections[1] == sections[1] && this->sections[2] == sections[2] && this->sections[3] == sections[3];
+    }
+
+    void set(int64_t theta_scale_length,
+             int64_t position_length,
+             float    ext_factor,
+             float   theta_scale,
+             float   freq_scale,
+             float   attn_factor,
+             bool    is_neox,
+             bool    indep_sects,
+             bool    mrope_used,
+             bool    is_imrope,
+             int     sections[4]) {
+        this->theta_scale_length = theta_scale_length;
+        this->position_length    = position_length;
+        this->ext_factor         = ext_factor;
+        this->theta_scale        = theta_scale;
+        this->freq_scale         = freq_scale;
+        this->attn_factor        = attn_factor;
+        this->is_neox            = is_neox;
+        this->indep_sects        = indep_sects;
+        this->mrope_used         = mrope_used;
+        this->is_imrope          = is_imrope;
+        this->sections[0]        = sections[0];
+        this->sections[1]        = sections[1];
+        this->sections[2]        = sections[2];
+        this->sections[3]        = sections[3];
     }
 
-    void *  theta_scale_cache  = nullptr;
-    int64_t theta_scale_length = 0;
+    // memory cache, prepare before inferencing.
+    void *  theta_scale_cache          = nullptr;
+    float * theta_scale_exp_host       = nullptr;
+    int *   position_select_index_host = nullptr;
+    void *  position_select_index      = nullptr;
     // sin/cos cache, used only to accelerate first layer on each device
-    void *  sin_cache          = nullptr;
-    void *  cos_cache          = nullptr;
-    int64_t position_length    = 0;
+    void *  sin_cache                  = nullptr;
+    void *  cos_cache                  = nullptr;
     // Properties to check before reusing the sincos cache
-    bool    cached             = false;
-    float   ext_factor         = 0.0f;
-    float   theta_scale        = 0.0f;
-    float   freq_scale         = 0.0f;
-    float   attn_factor        = 0.0f;
-    bool    is_neox            = false;
+    int64_t theta_scale_length         = 0;
+    int64_t position_length            = 0;
+    bool    cached                     = false;
+    float   ext_factor                 = 0.0f;
+    float   theta_scale                = 0.0f;
+    float   freq_scale                 = 0.0f;
+    float   attn_factor                = 0.0f;
+    bool    is_neox                    = false;
+    bool    indep_sects                = false;
+    bool    mrope_used                 = false;
+    int     sections[4]                = { 0, 0, 0, 0 };
+    bool    is_imrope                  = false;
 };
 
 struct ggml_cann_tensor_cache {

+ 0 - 7
ggml/src/ggml-cann/ggml-cann.cpp

@@ -2480,13 +2480,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
                     return false;
                 }
 
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
                 if (op->src[0]->ne[0] > 896) {
                     return false;
                 }