|
@@ -2360,6 +2360,21 @@ static enum ggml_status ggml_backend_cann_graph_compute(
|
|
|
bool use_cann_graph = true;
|
|
bool use_cann_graph = true;
|
|
|
bool cann_graph_update_required = false;
|
|
bool cann_graph_update_required = false;
|
|
|
|
|
|
|
|
|
|
+ static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
|
|
|
|
|
+ if (!prefill_use_graph) {
|
|
|
|
|
+ // Do not use acl_graph for prefill.
|
|
|
|
|
+ for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
|
|
|
+ ggml_tensor * node = cgraph->nodes[i];
|
|
|
|
|
+ // TODO: Optimize here. Currently, we can only
|
|
|
|
|
+ // get seq_len by FA's input.
|
|
|
|
|
+ if (node->op == GGML_OP_FLASH_ATTN_EXT) {
|
|
|
|
|
+ // Q -> src[0], shape: [B, S, N, D]
|
|
|
|
|
+ use_cann_graph = (node->src[0]->ne[1] == 1);
|
|
|
|
|
+ break;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
if (!cann_ctx->acl_graph_mode) {
|
|
if (!cann_ctx->acl_graph_mode) {
|
|
|
use_cann_graph = false;
|
|
use_cann_graph = false;
|
|
|
}
|
|
}
|