소스 검색

CANN: Improve ACL graph matching (#16166)

* CANN: improve ACL graph matching

Record `ne` and `nb` information for src tensors and include them in the
graph matching check. This enhances the robustness of ACL graph matching
by preventing incorrect matches when src tensors share the same data
address but differ in shape or stride.

* CANN: add op_params match
Chenguang Li 3 달 전
부모
커밋
aa4711d369
2개의 변경된 파일45개의 추가작업 그리고 12개의 파일을 삭제
  1. 8 1
      ggml/src/ggml-cann/common.h
  2. 37 11
      ggml/src/ggml-cann/ggml-cann.cpp

+ 8 - 1
ggml/src/ggml-cann/common.h

@@ -341,11 +341,18 @@ private:
 
 
 #ifdef USE_ACL_GRAPH
 #ifdef USE_ACL_GRAPH
 struct ggml_graph_node_properties {
 struct ggml_graph_node_properties {
+    // dst tensor
     void * node_address;
     void * node_address;
-    ggml_op node_op;
     int64_t ne[GGML_MAX_DIMS];
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
+
+    // src tensor
     void * src_address[GGML_MAX_SRC];
     void * src_address[GGML_MAX_SRC];
+    int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS];
+    size_t  src_nb[GGML_MAX_SRC][GGML_MAX_DIMS];
+
+    // op
+    ggml_op node_op;
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 };
 
 

+ 37 - 11
ggml/src/ggml-cann/ggml-cann.cpp

@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
         std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
         std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
 
 
         for (int src = 0; src < GGML_MAX_SRC; ++src) {
         for (int src = 0; src < GGML_MAX_SRC; ++src) {
-            prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
+            if (node->src[src]) {
+                prop.src_address[src] = node->src[src]->data;
+                std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
+                std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
+            } else {
+                prop.src_address[src] = nullptr;
+                std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
+                std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
+            }
         }
         }
 
 
         memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
         memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
  * @param graph_node_properties The stored properties of a CANN graph node.
  * @param graph_node_properties The stored properties of a CANN graph node.
  * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
  * @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
  */
  */
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
+static bool ggml_graph_node_has_matching_properties(
+        ggml_tensor * node,
+        ggml_graph_node_properties * graph_node_properties) {
     if (node->data != graph_node_properties->node_address &&
     if (node->data != graph_node_properties->node_address &&
-           node->op != GGML_OP_VIEW) {
+            node->op != GGML_OP_VIEW) {
         return false;
         return false;
     }
     }
+
     if (node->op != graph_node_properties->node_op) {
     if (node->op != graph_node_properties->node_op) {
         return false;
         return false;
     }
     }
+
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         if (node->ne[i] != graph_node_properties->ne[i]) {
         if (node->ne[i] != graph_node_properties->ne[i]) {
             return false;
             return false;
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
             return false;
             return false;
         }
         }
     }
     }
+
     for (int i = 0; i < GGML_MAX_SRC; i++) {
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != graph_node_properties->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+        if (node->src[i]) {
+            if (node->src[i]->data != graph_node_properties->src_address[i] &&
+                node->op != GGML_OP_VIEW) {
+                return false;
+            }
+
+            for (int d = 0; d < GGML_MAX_DIMS; d++) {
+                if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
+                    return false;
+                }
+                if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
+                    return false;
+                }
+            }
+        } else {
+            if (graph_node_properties->src_address[i] != nullptr) {
+                return false;
+            }
         }
         }
     }
     }
-    if (node->op == GGML_OP_SCALE &&
-        memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
-        return false;
+
+    if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
+        return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
     }
     }
     return true;
     return true;
 }
 }