Просмотр исходного кода

ggml-cuda: add stricter checking for fusion (#17568)

* ggml-cuda: make conditions for fusion more explicit

* ggml-cuda: remove size check as std::equal already does it
Aman Gupta 1 месяц назад
Родитель
Сommit
2e7ef98f18
1 измененных файлов с 15 добавлено и 11 удалено
  1. 15 11
      ggml/src/ggml-cuda/ggml-cuda.cu

+ 15 - 11
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
         ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
 
-    if (ops.size() == topk_moe_ops_with_norm.size() &&
+    const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
+                             const std::initializer_list<enum ggml_op> & list2) {
+        return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
+    };
+
+    if (is_equal(topk_moe_ops_with_norm, ops) &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@@ -3060,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == topk_moe_ops.size() &&
-        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
+    if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
         ggml_tensor * weights = cgraph->nodes[node_idx + 4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@@ -3069,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
+    if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
         ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
         ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
     std::initializer_list<enum ggml_op> mul_mat_glu_ops    = { GGML_OP_MUL_MAT,    GGML_OP_MUL_MAT,    GGML_OP_GLU };
 
-    if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
-                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
-
+    if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
         const ggml_tensor * ffn_gate      = cgraph->nodes[node_idx];
         const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
         const ggml_tensor * ffn_up        = cgraph->nodes[node_idx + 2];
@@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
-                            ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
-
+    if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
+        ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
         const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
         const ggml_tensor * ffn_up   = cgraph->nodes[node_idx + 1];
         const ggml_tensor * glu      = cgraph->nodes[node_idx + 2];
@@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
-    if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+    std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
+
+    if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
         const ggml_tensor * rope     = cgraph->nodes[node_idx];
         const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
         const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];