Browse Source

CANN: support gated linear attn (#18653)

* CANN: support gated linear attn

This change adds support for the GGML_OP_GATED_LINEAR_ATTN operator.
The feature was implemented by YushengZhao. Because the previous
submission was based on an outdated codebase, this PR was rebased to
merge.

Co-authored-by: YushengZhao <yusheng.chao@outlook.com>
Co-authored-by: hipudding <huafengchun@gmail.com>

* CANN: optimize OP gla

Optimize gla for high preformance

* Remove unused comments

---------

Co-authored-by: 赵禹昇 <2501112001@cninfer02.localdomain>
Co-authored-by: YushengZhao <yusheng.chao@outlook.com>
hipudding 2 weeks ago
parent
commit
baa4ba0aec
3 changed files with 186 additions and 161 deletions
  1. 143 77
      ggml/src/ggml-cann/aclnn_ops.cpp
  2. 39 84
      ggml/src/ggml-cann/aclnn_ops.h
  3. 4 0
      ggml/src/ggml-cann/ggml-cann.cpp

+ 143 - 77
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -58,6 +58,7 @@
 #include <aclnnop/aclnn_mean.h>
 #include <aclnnop/aclnn_mm.h>
 #include <aclnnop/aclnn_mul.h>
+#include <aclnnop/aclnn_mv.h>
 #include <aclnnop/aclnn_permute.h>
 #include <aclnnop/aclnn_pow.h>
 #include <aclnnop/aclnn_pow_tensor_tensor.h>
@@ -2338,20 +2339,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
 
     // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor.
     // TODO: acl_yarn_ramp_tensor use rope cache.
-    bool                 yarn_ramp_tensor_updated = false;
-    acl_tensor_ptr       acl_yarn_ramp_tensor;
+    bool           yarn_ramp_tensor_updated = false;
+    acl_tensor_ptr acl_yarn_ramp_tensor;
     if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length ||
                             ctx.rope_cache.freq_scale != freq_scale)) {
         yarn_ramp_tensor_updated = true;
         if (ctx.rope_cache.yarn_ramp_cache != nullptr) {
             ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache));
         }
-        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
+        ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float),
+                              ACL_MEM_MALLOC_HUGE_FIRST));
         // -rope_yarn_ramp
         // const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
         // return MIN(1, MAX(0, y)) - 1;
-        acl_yarn_ramp_tensor =
-            ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+        acl_yarn_ramp_tensor      = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+                                                            theta_scale_ne, theta_scale_nb, 1);
         float          zero_value = 0, one_value = 1;
         float          denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
         acl_scalar_ptr low              = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT);
@@ -2382,8 +2384,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx,
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get());
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get());
     } else {
-        acl_yarn_ramp_tensor =
-            ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1);
+        acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float),
+                                                       theta_scale_ne, theta_scale_nb, 1);
     }
     // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale.
     if (ext_factor != 0) {
@@ -2991,20 +2993,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get());
 }
 
-void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     ggml_tensor * src0 = dst->src[0];
     ggml_tensor * src1 = dst->src[1];
 
     // stride
-    int64_t s0 = ((const int32_t*)(dst->op_params))[0];
+    int64_t s0 = ((const int32_t *) (dst->op_params))[0];
 
-    acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_input  = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL);
     acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL);
-    acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_dst    = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL);
 
     // get base information of input and kernel
-    int64_t input_len = *(src1->ne);
-    int64_t dst_len = *(dst->ne);
+    int64_t input_len   = *(src1->ne);
+    int64_t dst_len     = *(dst->ne);
     int64_t kernel_size = *(src0->ne);
 
     // set the max kernel size for each conv
@@ -3012,56 +3014,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
     // compute the partition of kernel
     int64_t part_num = 1;
-    part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size;
+    part_num         = (kernel_size + max_kernel_size - 1) / max_kernel_size;
 
     int64_t strideVal[1];
-    strideVal[0] = s0;
-    acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1);
-    int64_t paddingVal[] = {0};
-    acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1);
-    int64_t dilationVal[] = {1};
-    acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1);
-    bool transposed = true;
-    int64_t groups = 1;
-    int8_t cubeMathType = 0;
+    strideVal[0]                    = s0;
+    acl_int_array_ptr stride        = ggml_cann_create_int_array(strideVal, 1);
+    int64_t           paddingVal[]  = { 0 };
+    acl_int_array_ptr padding       = ggml_cann_create_int_array(paddingVal, 1);
+    int64_t           dilationVal[] = { 1 };
+    acl_int_array_ptr dilation      = ggml_cann_create_int_array(dilationVal, 1);
+    bool              transposed    = true;
+    int64_t           groups        = 1;
+    int8_t            cubeMathType  = 0;
 
 #ifdef ASCEND_310P
     cubeMathType = 1;
 #endif
 
     auto weight_type = ggml_cann_type_mapping(src0->type);
-    auto dst_type = ggml_cann_type_mapping(dst->type);
+    auto dst_type    = ggml_cann_type_mapping(dst->type);
 
     // slice the kernel to make each conv available
-    int64_t slice_dim = -1;
+    int64_t slice_dim   = -1;
     int64_t slice_start = 0;
-    int64_t slice_end = max_kernel_size;
-    int64_t slice_step = 1;
-    int64_t interval = max_kernel_size;
+    int64_t slice_end   = max_kernel_size;
+    int64_t slice_step  = 1;
+    int64_t interval    = max_kernel_size;
 
-    int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
+    int64_t left_pad_len  = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0];
     int64_t right_pad_len = 0;
 
-    acl_scalar_ptr alpha = nullptr;
-    float alphaValue = 1.0;
-    alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
+    acl_scalar_ptr alpha      = nullptr;
+    float          alphaValue = 1.0;
+    alpha                     = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT);
 
     // set zero to destination
     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get());
 
-    for(int k = 0; k < part_num; k++){
-
+    for (int k = 0; k < part_num; k++) {
         // create part kernel tensor and slice from big kernel
         slice_start = max_kernel_size * k;
-        if(k == part_num - 1){
+        if (k == part_num - 1) {
             slice_end = kernel_size;
-            interval = kernel_size - max_kernel_size * k;
-        }else{
-            slice_end = max_kernel_size * (k+1);
+            interval  = kernel_size - max_kernel_size * k;
+        } else {
+            slice_end = max_kernel_size * (k + 1);
         }
 
         int64_t part_ne[4];
-        for(int i = 0; i < 4; i++) {
+        for (int i = 0; i < 4; i++) {
             part_ne[i] = *(src0->ne + i);
         }
         part_ne[0] = interval;
@@ -3074,16 +3075,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
         ggml_cann_pool_alloc part_kernel_allocator;
         part_kernel_allocator.alloc(ctx.pool(), part_nb[3]);
-        void* part_kernel_buf = part_kernel_allocator.get();
+        void * part_kernel_buf = part_kernel_allocator.get();
 
-        acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type,
-                                ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL);
+        acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0),
+                                                             part_ne, part_nb, 3, ACL_FORMAT_NCL);
 
-        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step,
+                                part_kernel.get());
 
         // create the part conv result tensor
         int64_t part_dst_ne[4];
-        for(int i = 0; i < 4; i++){
+        for (int i = 0; i < 4; i++) {
             part_dst_ne[i] = *(dst->ne + i);
         }
         part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1;
@@ -3095,32 +3097,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
         }
         ggml_cann_pool_alloc part_dst_allocator;
         part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]);
-        void* part_dst_buf = part_dst_allocator.get();
+        void * part_dst_buf = part_dst_allocator.get();
 
         acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst),
-                                    part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
+                                                              part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL);
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get());
 
         // compute part conv transpose 1d
         GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(),
-        padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType);
+                                padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(),
+                                cubeMathType);
 
         // compute the position of part result in final result
         int64_t global_start = slice_start;
-        int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
+        int64_t global_end   = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len);
 
-        left_pad_len = global_start;
+        left_pad_len  = global_start;
         right_pad_len = dst_len - global_end;
 
-        std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len};
-        acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2);
+        std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len };
+        acl_int_array_ptr    padData    = ggml_cann_create_int_array(padDataVal.data(), 2);
 
-        acl_scalar_ptr pad_value = nullptr;
-        float pad_valueVal = 0.0;
-        pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
+        acl_scalar_ptr pad_value    = nullptr;
+        float          pad_valueVal = 0.0;
+        pad_value                   = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT);
 
         int64_t conv_result_ne[4];
-        for(int i = 0; i < 4; i++){
+        for (int i = 0; i < 4; i++) {
             conv_result_ne[i] = *(dst->ne + i);
         }
 
@@ -3132,13 +3135,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
 
         ggml_cann_pool_alloc conv_result_allocator;
         conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]);
-        void* conv_result_buf = conv_result_allocator.get();
+        void * conv_result_buf = conv_result_allocator.get();
 
         acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst),
-                                    conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
+                                                             conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL);
 
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get());
-        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get());
+        GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(),
+                                conv_result.get());
         GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get());
     }
 }
@@ -3742,15 +3746,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     // we want a view:  ne_w = { nc, 1, nr }   // [K, 1, C]
     // so that reversed dims -> [C, 1, K] which matches
     //   [out_channels, in_channels/groups, kernel_size]
-    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups]
+    int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 };  // [K, 1 input ch. per group, C groups]
     // Layout: src1 data is [K, C] with
     //   offset(k, c) = k*nb0 + c*nb1
     // We want offset_w(k, 0, c) = k*nb0 + c*nb1,
     // so we can reuse nb0 and nb1, and set nb2 = nb1.
-    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1
+    size_t  w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] };  // same as src1
 
-    acl_tensor_ptr acl_w = ggml_cann_create_tensor(
-        src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type),
+                                                   ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL);
 
     // 3) Output: dst is { d_inner, n_t, n_s } (CLN)
     //
@@ -3768,11 +3772,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     //   nb_y[0] = nr * sizeof(float);           // step in L
     //   nb_y[1] = sizeof(float);                // step in C
     //   nb_y[2] = nr * n_t * sizeof(float);     // step in N
-    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N]
-    size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t]
+    int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 };  // [L_out, C, N]
+    size_t  y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float),
+                                    dst->nb[3] };       // [nr, 1, nr * n_t]
 
-    acl_tensor_ptr acl_y = ggml_cann_create_tensor(
-        dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
+    acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type),
+                                                   ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL);
 
     // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") ---
     int64_t strideVal[1]   = { 1 };
@@ -3791,22 +3796,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
     cubeMathType = 1;
 #endif
 
-    GGML_CANN_CALL_ACLNN_OP(ctx,
-                            Convolution,
+    GGML_CANN_CALL_ACLNN_OP(ctx, Convolution,
                             acl_x.get(),    // input:  N, C, L_in = ncs
                             acl_w.get(),    // weight: [C, 1, K] with groups=nr
                             nullptr,        // bias
-                            stride.get(),
-                            padding.get(),
-                            dilation.get(),
-                            transposed,
-                            padding.get(),   // output padding (unused for non-transposed)
-                            groups,
-                            acl_y.get(),
-                            cubeMathType);
+                            stride.get(), padding.get(), dilation.get(), transposed,
+                            padding.get(),  // output padding (unused for non-transposed)
+                            groups, acl_y.get(), cubeMathType);
 }
 
-
 void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
                                      ggml_tensor *               add_node,
                                      ggml_tensor *               rms_norm_node) {
@@ -3860,3 +3858,71 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
                             eps,  // double type
                             acl_yout.get(), acl_rstd.get(), acl_xout.get());
 }
+
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * k = dst->src[0];
+    ggml_tensor * v = dst->src[1];
+    ggml_tensor * q = dst->src[2];
+    ggml_tensor * g = dst->src[3];
+    ggml_tensor * s = dst->src[4];
+
+    int64_t B = dst->src[4]->ne[1];
+    int64_t T = dst->src[0]->ne[2];
+    int64_t H = dst->src[0]->ne[1];
+    int64_t C = dst->ne[0];
+    int64_t D = C / H;
+    int64_t L = T / B;
+
+    int64_t ne_qkg[2] = { 1, D };
+    int64_t ne_s[2]   = { D, D };
+    int64_t ne_st[2]  = { ne_s[1], ne_s[0] };
+    int64_t ne_vo[2]  = { D, 1 };
+    int64_t ne_q[1]   = { D };
+    size_t  nb_base   = ggml_type_size(k->type);
+    size_t  nb_qkg[2] = { nb_base, nb_base };
+    size_t  nb_s[2]   = { nb_base, D * nb_base };
+    size_t  nb_st[2]  = { nb_s[1], nb_s[0] };
+    size_t  nb_vo[2]  = { nb_base, D * nb_base };
+    size_t  nb_q[1]   = { nb_base };
+
+    const float scale = ggml_get_op_params_f32(dst, 0);
+
+    acl_tensor_ptr acl_s     = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND);
+    acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base);
+    cann_copy(ctx, acl_s.get(), new_state.get());
+
+    for (int64_t b = 0; b < B; b++) {
+        for (int64_t h = 0; h < H; h++) {
+            size_t         s_offset = (b * (H * D * D) + h * (D * D)) * nb_base;
+            // D * D
+            acl_tensor_ptr acl_s_new =
+                ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            acl_tensor_ptr acl_s_new_t =
+                ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset);
+            for (int64_t l = 0; l < L; l++) {
+                size_t               qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base;
+                // D * 1
+                acl_tensor_ptr       acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                acl_tensor_ptr       acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                acl_tensor_ptr       acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // 1 * D
+                acl_tensor_ptr       acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset);
+                // D
+                acl_tensor_ptr       acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset);
+                // k ⊗ v
+                size_t               buf_size = D * D * nb_base;
+                ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size);
+                acl_tensor_ptr       tmp_tensor = ggml_cann_create_tensor(
+                    buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2);
+                aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get());
+                //s_new = g ⊗ s_old + k ⊗ v
+                aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr);
+                aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr);
+                // compute output
+                GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1);
+                aclnn_muls(ctx, acl_o.get(), scale, nullptr, true);
+            }
+        }
+    }
+}

+ 39 - 84
ggml/src/ggml-cann/aclnn_ops.h

@@ -814,67 +814,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  */
 void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
-/*
- * @brief A generic wrapper for ACL resources with custom deleter support.
- */
-using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>;
-
 /**
- * @brief Trait structure used to define how to destroy a given ACL resource type.
+ * @brief Forward Gated Linear Attention on the CANN backend.
  *
- * @tparam T ACL resource type.
- */
-template <typename T> struct acl_resource_traits;
-
-/**
- * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource.
- */
-template <> struct acl_resource_traits<aclTensor> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); }
-};
-
-/**
- * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource.
- */
-template <> struct acl_resource_traits<aclIntArray> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); }
-};
-
-/**
- * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource.
- */
-template <> struct acl_resource_traits<aclScalar> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); }
-};
-
-/**
- * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource.
- */
-template <> struct acl_resource_traits<aclTensorList> {
-    static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); }
-};
-
-/**
- * @brief Creates a generic ACL resource wrapper with proper destruction logic.
+ * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions:
+ *   k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H
+ *   s: initial state [B, H, D, D], where B is batch and D=C/H
+ * dst holds both outputs (o) and updated state; a scale factor is read from op params.
  *
- * @tparam T ACL resource type.
- * @param ptr Raw pointer to ACL resource.
- * @return any_acl_resource Smart pointer that handles destruction.
- */
-template <typename T> any_acl_resource make_acl_resource(T * ptr) {
-    return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); });
-}
-
-/**
- * @brief Registers multiple ACL resources into a vector for lifetime management.
+ * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale.
  *
- * @tparam Args Variadic list of ACL resource types.
- * @param vec Target vector to hold ACL resources.
- * @param args Raw pointers to ACL resources.
+ * @param ctx Backend context providing stream/allocator utilities.
+ * @param dst Output tensor; src deps are k, v, q, g, s as above.
  */
-template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) {
-    (vec.emplace_back(make_acl_resource(args)), ...);
-}
+void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
 /**
  * @brief Launches an asynchronous task using the memory allocator.
@@ -894,19 +847,19 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
  * same stream are executed in queue order.
  */
 
-#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                           \
-    do {                                                                                     \
-        uint64_t        workspaceSize = 0;                                                   \
-        aclOpExecutor * executor;                                                            \
-        void *          workspaceAddr = nullptr;                                             \
-        ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
-        /* workspace should alloced in main thread to keep malloc order when using vmm. */   \
-        if (workspaceSize > 0) {                                                             \
-            ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);             \
-            workspaceAddr = workspace_allocator.get();                                       \
-        }                                                                                    \
-        ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));     \
-    } while (0)
+#    define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...)                                           \
+        do {                                                                                     \
+            uint64_t        workspaceSize = 0;                                                   \
+            aclOpExecutor * executor;                                                            \
+            void *          workspaceAddr = nullptr;                                             \
+            ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \
+            /* workspace should alloced in main thread to keep malloc order when using vmm. */   \
+            if (workspaceSize > 0) {                                                             \
+                ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize);             \
+                workspaceAddr = workspace_allocator.get();                                       \
+            }                                                                                    \
+            ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream()));     \
+        } while (0)
 
 /**
  * @brief   Performs sparse expert-based matrix multiplication using the CANN backend.
@@ -947,7 +900,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
  * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights
  *                        and epsilon parameter.
  */
-void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+                                     ggml_tensor *               add_node,
+                                     ggml_tensor *               rms_norm_node);
 
 /**
  * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
@@ -1104,13 +1059,13 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
  * @see ggml_cann_op_unary
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_OP_UNARY(OP_NAME)                                                              \
-    do {                                                                                              \
-        auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
-            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \
-        };                                                                                            \
-        ggml_cann_op_unary(lambda, ctx, dst);                                                         \
-    } while (0)
+#    define GGML_CANN_CALL_OP_UNARY(OP_NAME)                                                              \
+        do {                                                                                              \
+            auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
+                GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \
+            };                                                                                            \
+            ggml_cann_op_unary(lambda, ctx, dst);                                                         \
+        } while (0)
 
 /**
  * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
@@ -1133,13 +1088,13 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac
  * @see ggml_cann_op_unary_gated
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                                                        \
-    do {                                                                                              \
-        auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
-            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \
-        };                                                                                            \
-        ggml_cann_op_unary_gated(lambda, ctx, dst);                                                   \
-    } while (0)
+#    define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                                                        \
+        do {                                                                                              \
+            auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \
+                GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);                                  \
+            };                                                                                            \
+            ggml_cann_op_unary_gated(lambda, ctx, dst);                                                   \
+        } while (0)
 
 #endif  // CANN_ACLNN_OPS
 

+ 4 - 0
ggml/src/ggml-cann/ggml-cann.cpp

@@ -1889,6 +1889,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
             break;
+        case GGML_OP_GATED_LINEAR_ATTN:
+            ggml_cann_gated_linear_attn(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cann_ssm_conv(ctx, dst);
             break;
@@ -2454,6 +2457,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
         case GGML_OP_MEAN:
         case GGML_OP_PAD_REFLECT_1D:
         case GGML_OP_COUNT_EQUAL:
+        case GGML_OP_GATED_LINEAR_ATTN:
             return true;
         case GGML_OP_OUT_PROD:
             {