|
|
@@ -2186,7 +2186,15 @@ static void add_lru_matched_graph_node_properties(
|
|
|
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
|
|
|
|
|
|
for (int src = 0; src < GGML_MAX_SRC; ++src) {
|
|
|
- prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
|
|
|
+ if (node->src[src]) {
|
|
|
+ prop.src_address[src] = node->src[src]->data;
|
|
|
+ std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]);
|
|
|
+ std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]);
|
|
|
+ } else {
|
|
|
+ prop.src_address[src] = nullptr;
|
|
|
+ std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0);
|
|
|
+ std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
|
|
|
@@ -2206,14 +2214,18 @@ static void add_lru_matched_graph_node_properties(
|
|
|
* @param graph_node_properties The stored properties of a CANN graph node.
|
|
|
* @return true if all fields match (excluding GGML_OP_VIEW); false otherwise.
|
|
|
*/
|
|
|
-static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
|
|
|
+static bool ggml_graph_node_has_matching_properties(
|
|
|
+ ggml_tensor * node,
|
|
|
+ ggml_graph_node_properties * graph_node_properties) {
|
|
|
if (node->data != graph_node_properties->node_address &&
|
|
|
- node->op != GGML_OP_VIEW) {
|
|
|
+ node->op != GGML_OP_VIEW) {
|
|
|
return false;
|
|
|
}
|
|
|
+
|
|
|
if (node->op != graph_node_properties->node_op) {
|
|
|
return false;
|
|
|
}
|
|
|
+
|
|
|
for (int i = 0; i < GGML_MAX_DIMS; i++) {
|
|
|
if (node->ne[i] != graph_node_properties->ne[i]) {
|
|
|
return false;
|
|
|
@@ -2222,17 +2234,31 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
|
|
- if (node->src[i] &&
|
|
|
- node->src[i]->data != graph_node_properties->src_address[i] &&
|
|
|
- node->op != GGML_OP_VIEW
|
|
|
- ) {
|
|
|
- return false;
|
|
|
+ if (node->src[i]) {
|
|
|
+ if (node->src[i]->data != graph_node_properties->src_address[i] &&
|
|
|
+ node->op != GGML_OP_VIEW) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (int d = 0; d < GGML_MAX_DIMS; d++) {
|
|
|
+ if (node->src[i]->ne[d] != graph_node_properties->src_ne[i][d]) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (node->src[i]->nb[d] != graph_node_properties->src_nb[i][d]) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ if (graph_node_properties->src_address[i] != nullptr) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
- if (node->op == GGML_OP_SCALE &&
|
|
|
- memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
|
|
|
- return false;
|
|
|
+
|
|
|
+ if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) {
|
|
|
+ return memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0;
|
|
|
}
|
|
|
return true;
|
|
|
}
|