瀏覽代碼

CANN: Improve ACL graph matching (#16166)

* CANN: improve ACL graph matching

Record `ne` and `nb` information for src tensors and include them in the
graph matching check. This enhances the robustness of ACL graph matching
by preventing incorrect matches when src tensors share the same data
address but differ in shape or stride.

* CANN: add op_params match
Chenguang Li 3 月之前
父節點
當前提交
aa4711d369
共有 2 個文件被更改,包括 45 次插入12 次删除
  1. 8 1
      ggml/src/ggml-cann/common.h
  2. 37 11
      ggml/src/ggml-cann/ggml-cann.cpp

+ 8 - 1
ggml/src/ggml-cann/common.h

@@ -341,11 +341,18 @@ private:
 
 #ifdef USE_ACL_GRAPH
 struct ggml_graph_node_properties {
+    // dst tensor
     void * node_address;
-    ggml_op node_op;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
+
+    // src tensor
     void * src_address[GGML_MAX_SRC];
+    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
+    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
+
+    // op
+    ggml_op node_op;
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 

+ 37 - 11
ggml/src/ggml-cann/ggml-cann.cpp

@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
         std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
 
         for (int src = 0; src < GGML_MAX_SRC; ++src) {
-            prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+            if (node->src[src]) {
+                prop.src_address[src] = node->src[src]->data;
+                std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+                std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+            } else {
+                prop.src_address[src] = nullptr;
+                std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+                std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+            }
         }
 
         memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
  * @param graph_node_properties The stored properties of a CANN graph node.
  * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
  */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+static bool ggml_graph_node_has_matching_properties(
+        ggml_tensor * node,
+        ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
-           node->op != GGML_OP_VIEW) {
+            node->op != GGML_OP_VIEW) {
         return false;
     }
+
     if (node->op != graph_node_properties->node_op) {
         return false;
     }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != graph_node_properties->ne[i]) {
             return false;
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
             return false;
         }
     }
+
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+        if (node->src[i]) {
+            if (node->src[i]->data != graph_node_properties->src_address[i] &&
+                node->op != GGML_OP_VIEW) {
+                return false;
+            }
+
+            for (int d = 0; d < GGML_MAX_DIMS; d++) {
+                if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
+                    return false;
+                }
+                if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
+                    return false;
+                }
+            }
+        } else {
+            if (graph_node_properties->src_address[i] != nullptr) {
+                return false;
+            }
         }
     }
-    if (node->op == GGML_OP_SCALE &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
-        return false;
+
+    if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+        return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
     }
     return true;
 }