Browse Source

CANN: Refactor `evaluate_and_capture_cann_graph` (#17333)

* CANN: Refactor `evaluate_and_capture_cann_graph`

**Description of the problem**

* `matched_graph` is obtained even if graph mode is disabled.
* End of graph capture and graph replay are unnecessarily placed in different `if` blocks.

**Proposed solution**

* Obtain `matched_graph` only if graph mode is enabled.
* Place end of graph capture and graph reply inside the same `if` block.
* Unify graph related comments.

* Remove trailing whitespace
Raul Torres 1 month ago
parent
commit
2370665e56
1 changed files with 8 additions and 7 deletions
  1. 8 7
      ggml/src/ggml-cann/ggml-cann.cpp

+ 8 - 7
ggml/src/ggml-cann/ggml-cann.cpp

@@ -2246,8 +2246,7 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
                                             bool &                      use_cann_graph,
                                             bool &                      cann_graph_update_required) {
 #ifdef USE_ACL_GRAPH
-    ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
-    if (use_cann_graph && cann_graph_update_required) {
+    if (use_cann_graph && cann_graph_update_required) {  // Begin CANN graph capture
         ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
     }
 #endif  // USE_ACL_GRAPH
@@ -2271,12 +2270,14 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
     }
 
 #ifdef USE_ACL_GRAPH
-    if (use_cann_graph && cann_graph_update_required) {  // End CANN graph capture
-        ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
-    }
-
     if (use_cann_graph) {
-        // Execute graph
+        ggml_cann_graph * matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
+
+        if (cann_graph_update_required) {  // End CANN graph capture
+            ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
+        }
+
+        // Execute CANN graph
         ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
     }
 #endif  // USE_ACL_GRAPH