1
0
Эх сурвалжийг харах

CANN: add support for partial RoPE and Vision mode (#17543)

* cann: add support for partial RoPE and Vision mode

Add support for two important RoPE variants: partial rotation (rope_dims < ne0)
and Vision mode rotation.

1. Support for partial RoPE (rope_dims < ne0):
   - Split tensor into head (first rope_dims dimensions) and tail portions
   - Apply rotation only to head portion using RotaryPositionEmbedding operator
   - Copy unrotated tail portion directly from source to destination
   - Handle both contiguous and non-contiguous tensor layouts

2. Support for Vision mode (GGML_ROPE_TYPE_VISION):
   - Set rope_dims = ne0 for Vision mode to rotate entire tensor
   - Vision mode pairs dimension i with dimension i+n_dims (where n_dims = ne0/2)
   - No tail handling needed since entire tensor is rotated

Implementation details:
   - Use has_tail flag to determine execution path: head/tail splitting when
     rope_dims < ne0, or full tensor rotation when rope_dims == ne0
   - Support both F32 and F16 data types with intermediate F32 conversion
   - Copy non-contiguous tensors to contiguous buffers before calling
     RotaryPositionEmbedding operator for compatibility
   - Improve cache invalidation logic to include rope_dims and indep_sects
     parameters

These enhancements enable CANN backend to handle various RoPE configurations
used in modern vision-language models and models with partial rotation.

* cann: fix review comment
Chenguang Li 1 сар өмнө
parent
commit
ca709e427b

+ 153 - 61
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -2251,12 +2251,12 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
                                   int                         sections[4],
                                   bool                        mrope_used,
                                   bool                        is_imrope,
-                                  bool                        indep_sects) {
-    ggml_tensor * src0 = dst->src[0];  // input
+                                  bool                        indep_sects,
+                                  int64_t                     rope_dims) {
     ggml_tensor * src1 = dst->src[1];  // position
     ggml_tensor * src2 = dst->src[2];  // freq_factors
 
-    int64_t theta_scale_length = src0->ne[0] / 2;
+    int64_t theta_scale_length = rope_dims / 2;
     int64_t position_length    = dst->ne[2];
 
     // TODO: check theta_scale_length and position_length.
@@ -2331,18 +2331,17 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         ACL_CHECK(aclrtMemcpyAsync(ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float),
                                    ctx.rope_cache.theta_scale_exp_host, theta_scale_length * sizeof(float),
                                    ACL_MEMCPY_HOST_TO_DEVICE, ctx.stream()));
-
-        acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
-                                                         theta_scale_ne, theta_scale_nb, 1);
     }
+    acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
+                                                     theta_scale_ne, theta_scale_nb, 1);
 
     // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
+    // TODO: acl_yarn_ramp_tensor use rope cache.
     bool                 yarn_ramp_tensor_updated = false;
     ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
     acl_tensor_ptr       acl_yarn_ramp_tensor;
-    if (ext_factor != 0 &&
-        // TODO: check more parameter.
-        (ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) {
+    if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
+                            ctx.rope_cache.freq_scale != freq_scale)) {
         yarn_ramp_tensor_updated = true;
 
         // -rope_yarn_ramp
@@ -2590,7 +2589,7 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         aclnn_muls(ctx, acl_cos_tensor.get(), attn_factor, nullptr, true);
     }
 
-    int64_t sin_reshape_ne[4] = { src0->ne[0], 1, dst->ne[2], 1 };
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, dst->ne[2], 1 };
     size_t  sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2645,7 +2644,7 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
     // param
     float     freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
-    int sections[4];
+    int       sections[4];
     // const int n_past     = ((int32_t *) dst->op_params)[0];
     const int n_dims     = ((int32_t *) dst->op_params)[1];
     const int mode       = ((int32_t *) dst->op_params)[2];
@@ -2654,44 +2653,60 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
     GGML_TENSOR_UNARY_OP_LOCALS
 
-    memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
-    memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
-    memcpy(&ext_factor,  (int32_t *) dst->op_params +  7, sizeof(float));
-    memcpy(&attn_factor, (int32_t *) dst->op_params +  8, sizeof(float));
-    memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
-    memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
-    memcpy(&sections,    (int32_t *) dst->op_params + 11, sizeof(int)*4);
+    memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
+    memcpy(&sections, (int32_t *) dst->op_params + 11, sizeof(int) * 4);
 
-    // TODO: n_dims <= ne0
-    GGML_ASSERT(n_dims == ne0);
     GGML_ASSERT(n_dims % 2 == 0);
+    GGML_ASSERT(n_dims <= ne00);
 
     const float theta_scale = powf(freq_base, -2.0f / n_dims);
 
     float corr_dims[2];
     ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
 
-    bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
-    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; // qwen3vl apply interleaved mrope
-    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;  // ggml_rope_multi, note: also true for vision (24 & 8 == true) and for imrope
-    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+    bool       is_neox    = mode & GGML_ROPE_TYPE_NEOX;
+    const bool is_imrope  = mode == GGML_ROPE_TYPE_IMROPE;  // qwen3vl apply interleaved mrope
+    // mrope_used means the GGML_ROPE_TYPE_MROPE bit is set.
+    // Note: this bit is also set for imrope and some vision modes,
+    // so mrope_used does NOT exclusively indicate pure mrope.
+    const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE;
+    const bool is_vision  = mode == GGML_ROPE_TYPE_VISION;
 
     if (mrope_used) {
         GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
     }
 
     if (is_vision) {
-        GGML_ASSERT(n_dims == ne0/2);
+        GGML_ASSERT(n_dims == ne0 / 2);
     }
 
     if (is_imrope || mrope_used) {
         is_neox = true;
     }
 
+    int64_t rope_dims = n_dims;
+
+    //Our current RotaryPositionEmbedding does not support the VISION mode,
+    //but essentially it only modifies theta_base in mrope,
+    //then repeats it at the end in the same way as is_neox.
+    //In fact, RoPE is still applied across all dimensions.
+    if (is_vision) {
+        rope_dims = src0->ne[0];
+    }
+    int64_t tail_dims = ne00 - rope_dims;
+    bool    has_tail  = tail_dims > 0;
+
     // init ctx.rope_cos/rope_sin cache
-    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, mrope_used, is_imrope, is_vision);
+    aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections,
+                          mrope_used, is_imrope, is_vision, rope_dims);
 
-    int64_t sin_reshape_ne[4] = { ne00, 1, ne02, 1 };
+    // Cache is generated with ne00 dimensions, so we use ne00 for reshape
+    int64_t sin_reshape_ne[4] = { rope_dims, 1, ne02, 1 };
     size_t  sin_reshape_nb[GGML_MAX_DIMS];
     sin_reshape_nb[0] = sizeof(float);
     for (int i = 1; i < GGML_MAX_DIMS; i++) {
@@ -2704,7 +2719,6 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
 
     acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0);
     acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst);
-
 #ifdef ASCEND_310P
     // Special ROPE operation for 310P
 
@@ -2844,46 +2858,124 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     }
     return;
 #endif
-
     int64_t acl_mode = is_neox ? 0 : 1;
 
-    switch (src0->type) {
-        case GGML_TYPE_F32:
-            {
-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
-                                        acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
-                break;
-            }
-        case GGML_TYPE_F16:
-            {
-                ggml_cann_pool_alloc src_trans_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(float));
-                void *               src_trans_buffer = src_trans_allocator.get();
-                ggml_cann_pool_alloc dst_trans_allocator(ctx.pool(), ggml_nelements(dst) * sizeof(float));
-                void *               dst_trans_buffer = dst_trans_allocator.get();
+    // Pre-define head and tail dimensions for reuse
+    int64_t head_ne[GGML_MAX_DIMS] = { rope_dims, ne01, ne02, ne03 };
+    int64_t tail_ne[GGML_MAX_DIMS] = { tail_dims, ne01, ne02, ne03 };
+
+    // Step 1: Prepare trans tensors for F16 type conversion to F32 if needed
+    bool                 src_dst_need_trans = false;
+    ggml_cann_pool_alloc src_trans_allocator(ctx.pool());
+    ggml_cann_pool_alloc dst_trans_allocator(ctx.pool());
+    acl_tensor_ptr       acl_src_trans_tensor;
+    acl_tensor_ptr       acl_dst_trans_tensor;
+    void *               src_trans_buffer = nullptr;
+    void *               dst_trans_buffer = nullptr;
+    size_t               src_dst_trans_nb[GGML_MAX_DIMS];
+    if (src0->type == GGML_TYPE_F16) {
+        src_dst_need_trans = true;
+        src_trans_buffer   = src_trans_allocator.alloc(ggml_nelements(src0) * sizeof(float));
+        dst_trans_buffer   = dst_trans_allocator.alloc(ggml_nelements(dst) * sizeof(float));
+
+        src_dst_trans_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            src_dst_trans_nb[i] = src_dst_trans_nb[i - 1] * src0->ne[i - 1];
+        }
+        acl_src_trans_tensor = ggml_cann_create_tensor(src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        acl_dst_trans_tensor = ggml_cann_create_tensor(dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne,
+                                                       src_dst_trans_nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    }
+
+    // Step 2: Prepare head tensors for tail splitting if needed
+    acl_tensor_ptr acl_src_head;
+    acl_tensor_ptr acl_dst_head;
+    if (has_tail) {
+        // Create head views for RotaryPositionEmbedding (only first rope_dims dimensions)
+        // RotaryPositionEmbedding requires contiguous dst tensor, so we use a temporary buffer
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src_trans_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                                   src_dst_trans_nb, GGML_MAX_DIMS);
+        } else {
+            // Use original F32 tensor strides
+            acl_src_head = ggml_cann_create_tensor((char *) src0->data, ACL_FLOAT, sizeof(float), head_ne, src0->nb,
+                                                   GGML_MAX_DIMS);
+        }
 
-                size_t src_trans_nb[GGML_MAX_DIMS];
-                src_trans_nb[0] = sizeof(float);
-                for (int i = 1; i < GGML_MAX_DIMS; i++) {
-                    src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
-                }
+        int64_t              head_elements = rope_dims * ne01 * ne02 * ne03;
+        ggml_cann_pool_alloc dst_head_contiguous_allocator(ctx.pool(), head_elements * sizeof(float));
+        void *               dst_head_contiguous_buffer = dst_head_contiguous_allocator.get();
 
-                acl_tensor_ptr acl_src_trans_tensor = ggml_cann_create_tensor(
-                    src_trans_buffer, ACL_FLOAT, sizeof(float), src0->ne, src_trans_nb, GGML_MAX_DIMS);
-                acl_tensor_ptr acl_dst_trans_tensor = ggml_cann_create_tensor(
-                    dst_trans_buffer, ACL_FLOAT, sizeof(float), dst->ne, src_trans_nb, GGML_MAX_DIMS);
+        size_t head_contiguous_nb[GGML_MAX_DIMS];
+        head_contiguous_nb[0] = sizeof(float);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            head_contiguous_nb[i] = head_contiguous_nb[i - 1] * head_ne[i - 1];
+        }
+        acl_dst_head = ggml_cann_create_tensor(dst_head_contiguous_buffer, ACL_FLOAT, sizeof(float), head_ne,
+                                               head_contiguous_nb, GGML_MAX_DIMS);
+    }
 
-                aclnn_cast(ctx, acl_src.get(), acl_src_trans_tensor.get(), ACL_FLOAT);
+    // Step 3: Execute RotaryPositionEmbedding
+    if (has_tail) {
+        // Rotate only the head portion (first rope_dims dimensions)
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_head.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_head.get());
 
-                GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(),
-                                        acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode,
-                                        acl_dst_trans_tensor.get());
+        // Copy head result from contiguous buffer back to destination tensor
+        if (src_dst_need_trans) {
+            acl_tensor_ptr acl_dst_head_target = ggml_cann_create_tensor(
+                (char *) dst_trans_buffer, ACL_FLOAT, sizeof(float), head_ne, src_dst_trans_nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        } else {
+            acl_tensor_ptr acl_dst_head_target =
+                ggml_cann_create_tensor((char *) dst->data, ACL_FLOAT, sizeof(float), head_ne, dst->nb, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_dst_head.get(), acl_dst_head_target.get());
+        }
+    } else if (src_dst_need_trans) {
+        // Rotate full tensor (no tail), using trans tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get());
+    } else {
+        // Rotate full tensor (no tail), using original tensors
+        GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(),
+                                acl_sin_reshape_tensor.get(), acl_mode, acl_dst.get());
+    }
+
+    // Step 4: Copy unrotated tail portion from source to destination
+    if (has_tail) {
+        size_t src_tail_offset;
+        size_t dst_tail_offset;
+
+        auto copy_tail_device = [&](void * src_ptr, void * dst_ptr, aclDataType dtype, size_t elem_size,
+                                    size_t * nb_src_arr, size_t * nb_dst_arr) {
+            acl_tensor_ptr acl_src_tail =
+                ggml_cann_create_tensor(src_ptr, dtype, elem_size, tail_ne, nb_src_arr, GGML_MAX_DIMS);
+            acl_tensor_ptr acl_dst_tail =
+                ggml_cann_create_tensor(dst_ptr, dtype, elem_size, tail_ne, nb_dst_arr, GGML_MAX_DIMS);
+            cann_copy(ctx, acl_src_tail.get(), acl_dst_tail.get());
+        };
+
+        if (src_dst_need_trans) {
+            // Use F32 trans tensor strides and offsets
+            src_tail_offset = rope_dims * src_dst_trans_nb[0];
+            dst_tail_offset = rope_dims * src_dst_trans_nb[0];
+            copy_tail_device((char *) src_trans_buffer + src_tail_offset, (char *) dst_trans_buffer + dst_tail_offset,
+                             ACL_FLOAT, sizeof(float), src_dst_trans_nb, src_dst_trans_nb);
+        } else {
+            // Use original tensor strides and offsets
+            src_tail_offset = rope_dims * nb00;
+            dst_tail_offset = rope_dims * nb0;
+            copy_tail_device((char *) src0->data + src_tail_offset, (char *) dst->data + dst_tail_offset,
+                             ggml_cann_type_mapping(dst->type), ggml_element_size(dst), src0->nb, dst->nb);
+        }
+    }
 
-                aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
-                break;
-            }
-        default:
-            GGML_ABORT("Unsupported tensor type for GGML_OP_ROPE");
-            break;
+    // Step 5: Cast back to F16 if needed
+    if (src_dst_need_trans) {
+        aclnn_cast(ctx, acl_dst_trans_tensor.get(), acl_dst.get(), ACL_FLOAT16);
     }
 }
 

+ 2 - 2
ggml/src/ggml-cann/common.h

@@ -315,7 +315,7 @@ struct ggml_cann_rope_cache {
         if (theta_scale_exp_host) {
             free(theta_scale_exp_host);
         }
-        if(position_select_index_host) {
+        if (position_select_index_host) {
             free(position_select_index_host);
         }
     }
@@ -340,7 +340,7 @@ struct ggml_cann_rope_cache {
 
     void set(int64_t theta_scale_length,
              int64_t position_length,
-             float    ext_factor,
+             float   ext_factor,
              float   theta_scale,
              float   freq_scale,
              float   attn_factor,

+ 6 - 8
ggml/src/ggml-cann/ggml-cann.cpp

@@ -2308,7 +2308,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
 
     bool cann_graph_update_required = false;
 #ifdef USE_ACL_GRAPH
-    bool use_cann_graph             = true;
+    bool use_cann_graph = true;
 
     static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
     if (!prefill_use_graph) {
@@ -2338,7 +2338,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
         }
     }
 #else
-    bool use_cann_graph             = false;
+    bool use_cann_graph = false;
 #endif  // USE_ACL_GRAPH
     evaluate_and_capture_cann_graph(cann_ctx, cgraph, use_cann_graph, cann_graph_update_required);
 
@@ -2474,16 +2474,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
             }
         case GGML_OP_ROPE:
             {
-                // TODO: with ops-test v == 1
-                // TODO: n_dims <= ne0
-                if (op->src[0]->ne[0] != op->op_params[1]) {
-                    return false;
-                }
-
                 if (op->src[0]->ne[0] > 896) {
                     return false;
                 }
 #ifdef ASCEND_310P
+                // TODO: Support rope_dim < ne00(dim)
+                if (op->src[0]->ne[0] != op->op_params[1]) {
+                    return false;
+                }
                 if (!ggml_is_contiguous(op->src[0])) {
                     return false;
                 }