
CANN: add operator fusion support for ADD + RMS_NORM (#17512)

This commit implements operator fusion for ADD + RMS_NORM operations
in the CANN backend to reduce memory access overhead and improve
performance. The fusion is controlled by the GGML_CANN_OPERATOR_FUSION
environment variable (default: false).

Changes:
- Implement ggml_cann_op_add_rms_norm_fused() using ACLNN AddRmsNorm
- Add ggml_cann_can_fuse() to check fusion eligibility
- Integrate fusion logic into computation graph evaluation
- Add test cases for ADD + RMS_NORM fusion
- Update documentation with new environment variable

The fusion combines ADD and RMS_NORM into a single kernel call,
which is more efficient than executing them separately.
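
For reference, the fused call computes the following (a minimal scalar sketch of the AddRmsNorm semantics, assuming the standard RMS-norm definition ggml uses; the helper name is illustrative and not part of the commit):

```cpp
#include <cmath>
#include <cstddef>

// Scalar reference for the fused semantics, per row of length n:
//   x_out = x1 + x2
//   rstd  = 1 / sqrt(mean(x_out^2) + eps)
//   y_out = x_out * rstd * gamma
// (the ACLNN kernel additionally returns rstd as a separate output)
static void add_rms_norm_ref(const float * x1, const float * x2, const float * gamma,
                             float eps, size_t n, float * y_out, float * x_out) {
    double sum_sq = 0.0;
    for (size_t j = 0; j < n; j++) {
        x_out[j] = x1[j] + x2[j];                // the ADD half
        sum_sq += (double) x_out[j] * x_out[j];
    }
    const float rstd = 1.0f / std::sqrt((float) (sum_sq / (double) n) + eps);
    for (size_t j = 0; j < n; j++) {
        y_out[j] = x_out[j] * rstd * gamma[j];   // the RMS_NORM half
    }
}
```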
Chenguang Li 3 weeks ago
Parent
Commit
67e3f6f601

+ 4 - 0
docs/backend/CANN.md

@@ -327,3 +327,7 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe
 ### GGML_CANN_PREFILL_USE_GRAPH
 
 Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.
+
+### GGML_CANN_OPERATOR_FUSION
+
+Enable operator fusion during computation, default is false. This option fuses compatible operators (e.g., ADD + RMS_NORM) to reduce overhead and improve performance.

+ 55 - 0
ggml/src/ggml-cann/aclnn_ops.cpp

@@ -26,6 +26,7 @@
 #include "ggml.h"
 #include "ggml.h"
 
 
 #include <aclnnop/aclnn_add.h>
 #include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_add_rms_norm.h>
 #include <aclnnop/aclnn_addcdiv.h>
 #include <aclnnop/aclnn_argmax.h>
 #include <aclnnop/aclnn_avgpool2d.h>
@@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
                             cubeMathType);
 }
 
+
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
+                                     ggml_tensor *               add_node,
+                                     ggml_tensor *               rms_norm_node) {
+    // Get the two input tensors of the ADD operation
+    ggml_tensor * x1 = add_node->src[0];
+    ggml_tensor * x2 = add_node->src[1];
+
+    // Create ACL tensors for the two ADD inputs
+    acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
+    acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);
+
+    // Get the epsilon parameter from the RMS_NORM node's op_params
+    float eps;
+    memcpy(&eps, rms_norm_node->op_params, sizeof(float));
+
+    // Build gamma tensor (RMS normalization scaling factor)
+    // Gamma should match the normalized dimensions (last dimension of x1)
+    size_t acl_gamma_nb[GGML_MAX_DIMS];
+    acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
+    }
+    acl_tensor_ptr acl_gamma =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
+                             acl_gamma_nb, rms_norm_node->type,
+                             1,    // dims - only the last dimension
+                             1.0f  // value
+        );
+
+    // Build rstdOut tensor (output for normalized standard deviation)
+    // Shape should be the dimensions that are NOT normalized
+    int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
+    size_t  acl_rstd_nb[GGML_MAX_DIMS];
+    acl_rstd_nb[0] = sizeof(float);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
+    }
+    acl_tensor_ptr acl_rstd =
+        get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
+                             acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
+                             0.0f  // value
+        );
+
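+    // Create xOut tensor (receives the intermediate ADD result)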
+    acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);
+
+    // Create yOut tensor (final output after RMS normalization)
+    acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);
+
+    // Call fused ADD + RMS_NORM operator
+    GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
+                            eps,  // double type
+                            acl_yout.get(), acl_rstd.get(), acl_xout.get());
+}
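
For comparison, here is a minimal sketch of the unfused ggml subgraph that this fused call replaces (the helper name is hypothetical; ggml_add and ggml_rms_norm are the real ggml API, and the new test case at the bottom of this commit builds exactly this pattern):

```cpp
#include "ggml.h"

// ggml's RMS_NORM op carries no weight tensor, which is why the fused call
// above passes a cached all-ones gamma: a learned scale would appear as a
// separate GGML_OP_MUL node in the graph.
static ggml_tensor * add_rms_norm_unfused(ggml_context * ctx,
                                          ggml_tensor * x1, ggml_tensor * x2,
                                          float eps) {
    ggml_tensor * sum = ggml_add(ctx, x1, x2);   // corresponds to acl_xout
    return ggml_rms_norm(ctx, sum, eps);         // corresponds to acl_yout
}
```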

+ 14 - 0
ggml/src/ggml-cann/aclnn_ops.h

@@ -935,6 +935,20 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
  */
 void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
 
+/**
+ * @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
+ *
+ * This function fuses the ADD and RMS_NORM operations into a single kernel call
+ * for better performance. It first adds two input tensors (x1 + x2), then applies
+ * RMS normalization to the result.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param add_node The ADD operation node, which holds the two input tensors
+ *                 to be added.
+ * @param rms_norm_node The RMS_NORM operation node, which carries the epsilon
+ *                      parameter in its op_params.
+ */
+void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
+
 /**
  * @brief   Check whether a tensor is a weight tensor for matrix multiplication.
  *

+ 44 - 0
ggml/src/ggml-cann/ggml-cann.cpp

@@ -1888,6 +1888,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
             break;
         case GGML_OP_OUT_PROD:
             ggml_cann_out_prod(ctx, dst);
+            break;
         case GGML_OP_SSM_CONV:
             ggml_cann_ssm_conv(ctx, dst);
             break;
@@ -2077,6 +2078,40 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
     ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
 }
 
+/**
+ * @brief Check if CANN backend can fuse the specified operation sequence
+ *
+ * This function determines whether an operation sequence starting from the specified node
+ * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
+ * memory access overhead and improve computational efficiency.
+ *
+ * @param cgraph Pointer to the computation graph
+ * @param node_idx Index of the starting node in the computation graph
+ * @param ops Sequence of operation types to check for fusion
+ * @return true if the operations can be fused
+ * @return false if the operations cannot be fused
+ */
+static bool ggml_cann_can_fuse(const struct ggml_cgraph *          cgraph,
+                               int                                 node_idx,
+                               std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    // CANN backend supports fusing ADD + RMS_NORM operations
+    if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
+        ggml_tensor * add_node = cgraph->nodes[node_idx];
+        // TODO: support broadcast for ADD + RMS_NORM
+        if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
+            add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
+            return false;
+        }
+        return true;
+    }
+
+    return false;
+}
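
To make the shape gate concrete, a small sketch (the helper name is illustrative; the shapes come from the new test cases at the bottom of this commit):

```cpp
#include <cstdint>

// Mirrors the element-wise shape comparison above.
static bool same_shape_4d(const int64_t a[4], const int64_t b[4]) {
    return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3];
}

// With the shapes used by the new tests:
//   {64, 5, 4, 3}     vs {64, 5, 4, 3}  -> true:  the pair is fused
//   {128, 15, 12, 12} vs {64, 5, 4, 3}  -> false: broadcasting, so ADD and
//                                          RMS_NORM run as separate kernels
```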
+
 /**
  * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
  *
@@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
 #endif  // USE_ACL_GRAPH
     // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
     // With the use of CANN graphs, the execution will be performed by the graph launch.
+    static bool opt_fusion = parse_bool(get_env("GGML_CANN_OPERATOR_FUSION").value_or(""));
+
     if (!use_cann_graph || cann_graph_capture_required) {
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
+            if (opt_fusion) {
+                if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
+                    ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
+                    i++;
+                    continue;
+                }
+            }
 
             if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
                 node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {

+ 62 - 0
tests/test-backend-ops.cpp

@@ -3431,6 +3431,65 @@ struct test_rms_norm_mul_add : public test_case {
     }
 };
 
+// GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation)
+struct test_add_rms_norm : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const float eps;
+    const bool broadcast;
+
+    std::string op_desc(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return "ADD_RMS_NORM";
+    }
+
+    bool run_whole_graph() override { return true; }
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, eps, broadcast);
+    }
+
+    test_add_rms_norm(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {64, 5, 4, 3},
+            float eps = 1e-6f, bool broadcast = false)
+        : type(type), ne(ne), eps(eps), broadcast(broadcast) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4};
+
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data());
+        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+
+        ggml_set_param(a);
+        ggml_set_name(a, "a");
+        ggml_set_param(b);
+        ggml_set_name(b, "b");
+
+        // ADD operation followed by RMS_NORM
+        ggml_tensor * add_result = ggml_add(ctx, a, b);
+        ggml_set_name(add_result, "add_result");
+
+        ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -10.f, 10.f);
+        }
+    }
+
+    float grad_eps() override {
+        return 1.0f;
+    }
+
+    bool grad_precise() override {
+        return true;
+    }
+};
+
 // GGML_OP_SSM_CONV
 struct test_ssm_conv : public test_case {
     const ggml_type type;
@@ -7393,11 +7452,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
         test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false));
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true));
     }
     for (uint32_t n : {1, 511, 1025, 8192, 33*512}) {
         for (bool multi_add : {false, true}) {
             test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add));
         }
+        test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false));
     }
 
     for (auto multi_add : {false, true}) {