瀏覽代碼

ggml-cuda: reorder only relevant nodes (#17639)

Aman Gupta 1 月之前
父節點
當前提交
ed32089927
共有 2 個文件被更改,包括 62 次插入13 次删除
  1. 5 2
      ggml/src/ggml-cuda/common.cuh
  2. 57 11
      ggml/src/ggml-cuda/ggml-cuda.cu

+ 5 - 2
ggml/src/ggml-cuda/common.cuh

@@ -989,6 +989,10 @@ struct ggml_cuda_concurrent_event {
     int                                          n_streams = 0;
     std::unordered_map<const ggml_tensor *, int> stream_mapping;
 
+    // Original order of nodes in this concurrent region (before interleaving)
+    // Used to restore grouping for fusion within streams
+    std::vector<const ggml_tensor *> original_order;
+
     const ggml_tensor * join_node;
 
     ggml_cuda_concurrent_event() = default;
@@ -1011,6 +1015,7 @@ struct ggml_cuda_concurrent_event {
     , fork_event(other.fork_event)
     , n_streams(other.n_streams)
     , stream_mapping(std::move(other.stream_mapping))
+    , original_order(std::move(other.original_order))
     , join_node(other.join_node) {
         other.fork_event = nullptr;
     }
@@ -1121,11 +1126,9 @@ struct ggml_cuda_concurrent_event {
 };
 
 struct ggml_cuda_stream_context {
-    std::vector<const ggml_tensor *>                                    original_nodes;
     std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
 
     void reset() {
-        original_nodes.clear();
         concurrent_events.clear();
     }
 };

+ 57 - 11
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3238,9 +3238,56 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 }
             }
             if (should_launch_concurrent_events) {
-                //Restore the original graph to enable fusion within the streams
-                cgraph->nodes   = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
-                cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+                // Restore original node order within each concurrent region to enable fusion within streams
+
+                std::unordered_map<const ggml_tensor *, int> node_to_idx;
+                node_to_idx.reserve(cgraph->n_nodes);
+                for (int i = 0; i < cgraph->n_nodes; ++i) {
+                    node_to_idx[cgraph->nodes[i]] = i;
+                }
+
+                for (auto & [fork_node, event] : stream_ctx.concurrent_events) {
+                    // Find positions of all nodes from this event in the current graph
+                    std::vector<int> positions;
+                    positions.reserve(event.original_order.size());
+
+                    bool all_found = true;
+                    for (const ggml_tensor * orig_node : event.original_order) {
+                        auto it = node_to_idx.find(orig_node);
+                        if (it != node_to_idx.end()) {
+                            positions.push_back(it->second);
+                        } else {
+                            all_found = false;
+                            break;
+                        }
+                    }
+
+                    if (!all_found || positions.size() != event.original_order.size()) {
+                        continue;
+                    }
+
+                    // Sort positions to get contiguous range
+                    std::vector<int> sorted_positions = positions;
+                    std::sort(sorted_positions.begin(), sorted_positions.end());
+
+                    bool is_contiguous = true;
+                    for (size_t i = 1; i < sorted_positions.size(); ++i) {
+                        if (sorted_positions[i] != sorted_positions[i-1] + 1) {
+                            is_contiguous = false;
+                            break;
+                        }
+                    }
+
+                    if (!is_contiguous) {
+                        continue;
+                    }
+
+                    // Restore original order at the sorted positions
+                    int start_pos = sorted_positions[0];
+                    for (size_t i = 0; i < event.original_order.size(); ++i) {
+                        cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
+                    }
+                }
             }
 
             for (int i = 0; i < cgraph->n_nodes; i++) {
@@ -3805,14 +3852,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     // store {fork_idx, join_idx}
     std::vector<std::pair<int, int>> concurrent_node_ranges;
 
-    // save the original nodes
-    std::vector<const ggml_tensor *> original_nodes;
-    original_nodes.reserve(cgraph->n_nodes);
-    for (int i = 0; i < cgraph->n_nodes; ++i) {
-        original_nodes.push_back(cgraph->nodes[i]);
-    }
-    cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
-
     for (const auto & [root_node, count] : fan_out) {
         if (count >= min_fan_out && count <= max_fan_out) {
             const int root_node_idx = node_indices[root_node];
@@ -3917,6 +3956,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
                     continue;
                 }
 
+                // Save the original order of nodes in this region before interleaving
+                // This is used later to restore grouping for fusion within streams
+                concurrent_event.original_order.reserve(total_branch_nodes);
+                for (int i = fork_node_idx + 1; i < join_node_idx; ++i) {
+                    concurrent_event.original_order.push_back(cgraph->nodes[i]);
+                }
+
                 std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
                 GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
                 concurrent_events.emplace(root_node, std::move(concurrent_event));