Selaa lähdekoodia

CUDA: add stream-based concurrency (#16991)

* CUDA: add stream-based concurrency

* HIP: fix hipStreamWaitEvent define and nodiscard warnings

* ggml-cuda: fix fusion inside stream

* ggml-cuda: fix bug w.r.t first stream launch

* ggml-cuda: format

* ggml-cuda: improve assert message

* ggml-cuda: use lambda instead of duplicating code

* ggml-cuda: add some more comments

* ggml-cuda: add more detailed comments about concurrency

* ggml-cuda: rename + remove unused var

* ggml-cuda: fix condition for stream launch

* ggml-cuda: address review comments, add destructor

* common.cuh: add is_valid for concurrent events

* common.cuh: make comment better

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* update comment

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* common.cuh: fix lower_bound condition + remove join_node data from write_ranges

* ggml-cuda: fix overlap condition + shadowing parameter

---------

Co-authored-by: Carl Philipp Klemm <carl@uvos.xyz>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Aman Gupta 1 kuukausi sitten
vanhempi
sitoutus
c7af376c29
3 muutettua tiedostoa jossa 469 lisäystä ja 14 poistoa
  1. 162 8
      ggml/src/ggml-cuda/common.cuh
  2. 306 5
      ggml/src/ggml-cuda/ggml-cuda.cu
  3. 1 1
      ggml/src/ggml-cuda/vendors/hip.h

+ 162 - 8
ggml/src/ggml-cuda/common.cuh

@@ -21,10 +21,12 @@
 #include "ggml-common.h"
 
 #include <array>
+#include <algorithm>
 #include <cassert>
 #include <cfloat>
 #include <cstdio>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 #if defined(GGML_USE_HIP)
@@ -980,6 +982,154 @@ struct ggml_cuda_graph {
 #endif
 };
 
+struct ggml_cuda_concurrent_event {
+    std::vector<cudaEvent_t> join_events;
+    cudaEvent_t              fork_event = nullptr;
+
+    int                                          n_streams = 0;
+    std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+    const ggml_tensor * join_node;
+
+    ggml_cuda_concurrent_event() = default;
+
+    ggml_cuda_concurrent_event(const ggml_cuda_concurrent_event &) = delete;
+    ggml_cuda_concurrent_event & operator=(const ggml_cuda_concurrent_event &) = delete;
+
+    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+        join_events.resize(n_streams);
+
+        for (size_t i = 0; i < join_events.size(); ++i) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming));
+        }
+
+        CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
+    }
+
+    ggml_cuda_concurrent_event(ggml_cuda_concurrent_event && other) noexcept
+    : join_events(std::move(other.join_events))
+    , fork_event(other.fork_event)
+    , n_streams(other.n_streams)
+    , stream_mapping(std::move(other.stream_mapping))
+    , join_node(other.join_node) {
+        other.fork_event = nullptr;
+    }
+
+    // 1. check if any branches write to overlapping memory ranges (except the join node)
+    // 2. check whether all srcs are either within the branch or outside the nodes covered by ggml_cuda_concurrent_event
+    // we assume all nodes have the same buffer
+    bool is_valid() const {
+        std::vector<std::vector<std::pair<int64_t, int64_t>>> write_ranges;
+        write_ranges.resize(n_streams);
+
+        // get join_node's memory range to exclude from overlap checking.
+        // multiple nodes can use join_node's buffer; we synchronize on the join node.
+        const ggml_tensor * join_t     = join_node->view_src ? join_node->view_src : join_node;
+        const int64_t       join_start = (int64_t) join_t->data;
+        const int64_t       join_end   = join_start + ggml_nbytes(join_t);
+
+        for (const auto & [tensor, stream] : stream_mapping) {
+            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+            const int64_t       t_start = (int64_t) t->data;
+            const int64_t       t_end   = t_start + ggml_nbytes(t);
+
+            // skip tensors that overlap with join_node's buffer.
+            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+                continue;
+            }
+
+            // concurrent streams begin from 1
+            write_ranges[stream - 1].emplace_back(t_start, t_end);
+        }
+
+        for (int i = 0; i < n_streams; ++i) {
+            // sorts first by start then by end of write range
+            std::sort(write_ranges[i].begin(), write_ranges[i].end());
+        }
+
+        bool writes_overlap = false;
+        bool dependent_srcs = false;
+        for (const auto & [tensor, stream] : stream_mapping) {
+            const ggml_tensor * t = tensor->view_src ? tensor->view_src : tensor;
+            const int64_t       t_start = (int64_t) t->data;
+            const int64_t       t_end   = t_start + ggml_nbytes(t);
+
+            // skip tensors that overlap with join_node's buffer
+            if ((t_start <= join_start && join_start < t_end) || (join_start <= t_start && t_start < join_end)) {
+                continue;
+            }
+
+            // check if this buffer's write data overlaps with another stream's
+            std::pair<int64_t, int64_t> data_range = std::make_pair(t_start, t_end);
+            for (int i = 0; i < n_streams; ++i) {
+                if (i == stream - 1) {
+                    continue;
+                }
+                auto it = std::lower_bound(write_ranges[i].begin(), write_ranges[i].end(), data_range);
+
+                if (it != write_ranges[i].end()) {
+                    const std::pair<int64_t, int64_t> & other = *it;
+
+                    // std::lower_bound returns the first element where other >= data_range (lexicographically).
+                    // This guarantees other.first >= data_range.first.
+                    // Therefore, overlap occurs iff other.first < data_range.second
+                    // (i.e., the other range starts before this range ends).
+                    if (other.first < data_range.second) {
+                        GGML_LOG_DEBUG("Writes overlap for %s", tensor->name);
+                        writes_overlap = true;
+                        break;
+                    }
+                }
+            }
+
+            //check if all srcs are either in branch or don't have a branch
+            for (int i = 0; i < GGML_MAX_SRC; ++i) {
+                if (!tensor->src[i]) {
+                    continue;
+                }
+
+                auto it = stream_mapping.find(tensor->src[i]);
+
+                if (it == stream_mapping.end()) {
+                    continue;
+                }
+
+                if (it->second != stream) {
+                    dependent_srcs = true;
+                    break;
+                }
+            }
+
+            if (dependent_srcs || writes_overlap) {
+                break;
+            }
+        }
+
+        return !writes_overlap && !dependent_srcs;
+    }
+
+    ~ggml_cuda_concurrent_event() {
+        if (fork_event != nullptr) {
+            CUDA_CHECK(cudaEventDestroy(fork_event));
+        }
+        for (cudaEvent_t e : join_events) {
+            if (e != nullptr) {
+                CUDA_CHECK(cudaEventDestroy(e));
+            }
+        }
+    }
+};
+
+struct ggml_cuda_stream_context {
+    std::vector<const ggml_tensor *>                                    original_nodes;
+    std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> concurrent_events;
+
+    void reset() {
+        original_nodes.clear();
+        concurrent_events.clear();
+    }
+};
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -990,11 +1140,15 @@ struct ggml_backend_cuda_context {
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
 
+    int curr_stream_no = 0;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
+    ggml_cuda_stream_context concurrent_stream_context;
+
     ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
@@ -1005,9 +1159,9 @@ struct ggml_backend_cuda_context {
         return streams[device][stream];
     }
 
-    cudaStream_t stream() {
-        return stream(device, 0);
-    }
+    cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
 
     cublasHandle_t cublas_handle(int device) {
         if (cublas_handles[device] == nullptr) {
@@ -1023,15 +1177,15 @@ struct ggml_backend_cuda_context {
     }
 
     // pool
-    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
 
-    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
 
     ggml_cuda_pool & pool(int device) {
-        if (pools[device] == nullptr) {
-            pools[device] = new_pool_for_device(device);
+        if (pools[device][curr_stream_no] == nullptr) {
+            pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
         }
-        return *pools[device];
+        return *pools[device][curr_stream_no];
     }
 
     ggml_cuda_pool & pool() {

+ 306 - 5
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -522,7 +522,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 };
 #endif // defined(GGML_USE_VMM)
 
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int                  device,
+                                                                               [[maybe_unused]] int stream_no) {
 #if defined(GGML_USE_VMM)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@@ -3200,18 +3201,83 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     // flag used to determine whether it is an integrated_gpu
     const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
 
+    ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
+    bool                         is_concurrent_event_active = false;
+    ggml_cuda_concurrent_event * concurrent_event           = nullptr;
+    bool                         should_launch_concurrent_events = false;
+
+    const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
+        if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) {
+            concurrent_event = &stream_ctx.concurrent_events[node];
+
+            is_concurrent_event_active = true;
+
+            GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+
+            cudaStream_t main_stream = cuda_ctx->stream();  // this should be stream 0
+            GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+            CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+
+            for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+                cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+                CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+            }
+        }
+    };
+
     while (!graph_evaluated_or_captured) {
         // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
         // With the use of CUDA graphs, the execution will be performed by the graph launch.
         if (!use_cuda_graph || cuda_graph_update_required) {
-
             [[maybe_unused]] int prev_i = 0;
 
+            if (stream_ctx.concurrent_events.size() > 0) {
+                should_launch_concurrent_events = true;
+                for (const auto & [tensor, event] : stream_ctx.concurrent_events) {
+                    should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
+                }
+            }
+            if (should_launch_concurrent_events) {
+                //Restore the original graph to enable fusion within the streams
+                cgraph->nodes   = const_cast<ggml_tensor **>(stream_ctx.original_nodes.data());
+                cgraph->n_nodes = (int) stream_ctx.original_nodes.size();
+            }
+
             for (int i = 0; i < cgraph->n_nodes; i++) {
                 ggml_tensor * node = cgraph->nodes[i];
+                if (is_concurrent_event_active) {
+                    GGML_ASSERT(concurrent_event);
+
+                    if (node == concurrent_event->join_node) {
+                        cuda_ctx->curr_stream_no = 0;
+                        for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+                            // Wait on join events of forked streams in the main stream
+                            CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1],
+                                                       cuda_ctx->stream(cuda_ctx->device, i)));
+                            CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1]));
+                        }
+
+                        is_concurrent_event_active = false;
+                        concurrent_event           = nullptr;
+                    } else {
+                        GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end());
+                        cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+                        GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+                    }
+                } else if (i - prev_i > 1) {
+                    //the previous node was fused
+                    const ggml_tensor * prev_node = cgraph->nodes[i - 1];
+                    try_launch_concurrent_event(prev_node);
+
+                    if (is_concurrent_event_active) {
+                        cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node];
+                        GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name);
+                    }
+                }
+                prev_i = i;
+
 #ifdef GGML_CUDA_DEBUG
                 const int nodes_fused = i - prev_i - 1;
-                prev_i = i;
                 if (nodes_fused > 0) {
                     GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused);
                 }
@@ -3221,6 +3287,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                     continue;
                 }
 
+
+                // start of fusion operations
                 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
                 if (!disable_fusion) {
 
@@ -3513,13 +3581,17 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 }
 #else
                 GGML_UNUSED(integrated);
-#endif // NDEBUG
+#endif  // NDEBUG
 
                 bool ok = ggml_cuda_compute_forward(*cuda_ctx, node);
                 if (!ok) {
                     GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
                 }
                 GGML_ASSERT(ok);
+
+                if (!is_concurrent_event_active) {
+                    try_launch_concurrent_event(node);
+               }
             }
         }
 
@@ -3659,6 +3731,235 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
     }
 }
 
+static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
+
+    static bool enable_graph_optimization = [] {
+        const char * env = getenv("GGML_CUDA_GRAPH_OPT");
+        return env != nullptr && atoi(env) == 1;
+    }();
+
+    if (!enable_graph_optimization) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
+    GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
+
+    ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
+    stream_context.reset();
+
+    // number of out-degrees for a particular node
+    std::unordered_map<const ggml_tensor *, int> fan_out;
+    // reverse mapping of node to index in the cgraph
+    std::unordered_map<const ggml_tensor *, int> node_indices;
+
+    const auto & is_noop = [](const ggml_tensor * node) -> bool {
+        return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE ||
+               node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
+    };
+
+    const auto & depends_on = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool {
+        for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
+            if (dst->src[s] == src) {
+                return true;
+            }
+        }
+        // implicit dependency if they view the same tensor
+        const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst;
+        const ggml_tensor * src2 = src->view_src ? src->view_src : src;
+        if (dst2 == src2) {
+            return true;
+        }
+        return false;
+    };
+
+    for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) {
+        const ggml_tensor * node = cgraph->nodes[node_idx];
+        node_indices[node]       = node_idx;
+
+        if (is_noop(node)) {
+            continue;
+        }
+        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+            const ggml_tensor * src = cgraph->nodes[node_idx]->src[src_idx];
+            //TODO: check why nrows > 1 fails
+            if (node && !is_noop(node) && ggml_nrows(node) <= 1) {
+                fan_out[src] += 1;
+            }
+        }
+    }
+
+    // Target Q, K, V for concurrency
+    // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else):
+    // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm")
+    // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn")
+    // 3. account for all branches from the fork to the join
+    // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details)
+    // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams
+    // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030
+
+    const int min_fan_out = 3;
+    const int max_fan_out = 3;
+
+    // store {fork_idx, join_idx}
+    std::vector<std::pair<int, int>> concurrent_node_ranges;
+
+    // save the original nodes
+    std::vector<const ggml_tensor *> original_nodes;
+    original_nodes.reserve(cgraph->n_nodes);
+    for (int i = 0; i < cgraph->n_nodes; ++i) {
+        original_nodes.push_back(cgraph->nodes[i]);
+    }
+    cuda_ctx->stream_context().original_nodes = std::move(original_nodes);
+
+    for (const auto & [root_node, count] : fan_out) {
+        if (count >= min_fan_out && count <= max_fan_out) {
+            const int root_node_idx = node_indices[root_node];
+
+            bool is_part_of_event = false;
+            for (const auto & [start, end] : concurrent_node_ranges) {
+                if (root_node_idx >= start && root_node_idx <= end) {
+                    is_part_of_event = true;
+                }
+            }
+
+            if (is_part_of_event) {
+                continue;
+            }
+
+            std::vector<std::vector<const ggml_tensor *>> nodes_per_branch;
+            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+                const ggml_tensor * node = cgraph->nodes[i];
+                if (!is_noop(node) && depends_on(node, root_node)) {
+                    nodes_per_branch.push_back({ node });
+                }
+            }
+
+            GGML_ASSERT(nodes_per_branch.size() == (size_t) count);
+
+            //find the join point
+            const ggml_tensor * join_node = nullptr;
+
+            const auto & belongs_to_branch = [&](const ggml_tensor *                      node,
+                                                 const std::vector<const ggml_tensor *> & branch) -> bool {
+                for (const ggml_tensor * n : branch) {
+                    if (depends_on(node, n)) {
+                        return true;
+                    }
+                }
+                return false;
+            };
+
+            for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) {
+                const ggml_tensor * curr_node = cgraph->nodes[i];
+
+                int num_joins = 0;
+                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+                    if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) {
+                        num_joins++;
+                    }
+                }
+
+                if (num_joins >= 2) {
+                    join_node = curr_node;
+                    break;
+                }
+
+                bool found_branch = false;
+                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+                    std::vector<const ggml_tensor *> & branch_vec = nodes_per_branch[branch_idx];
+                    if (belongs_to_branch(curr_node, branch_vec)) {
+                        //continue accumulating
+                        if (std::find(branch_vec.begin(), branch_vec.end(), curr_node) == branch_vec.end()) {
+                            branch_vec.push_back(curr_node);
+                        }
+                        found_branch = true;
+                    }
+                }
+
+                if (!found_branch && is_noop(curr_node)) {
+                    // we can put it in any branch because it will be ignored
+                    nodes_per_branch[0].push_back({ curr_node });
+                }
+            }
+
+            if (join_node) {
+                //Create ggml_cuda_concurrent_event
+                ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size());
+                concurrent_event.join_node = join_node;
+
+                for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) {
+                    for (const ggml_tensor * n : nodes_per_branch[branch_idx]) {
+                        concurrent_event.stream_mapping[n] = branch_idx + 1;
+                    }
+                }
+
+                int fork_node_idx = node_indices[root_node];
+                int join_node_idx = node_indices[join_node];
+
+                int       current_branch_idx = 0;
+                int       current_node_idx   = fork_node_idx + 1;
+                const int n_branches         = nodes_per_branch.size();
+
+                int total_branch_nodes = 0;
+                for (std::vector<const ggml_tensor *> branch_nodes : nodes_per_branch) {
+                    total_branch_nodes += branch_nodes.size();
+                }
+
+                // there are other nodes in the middle which are unaccounted for
+                // usually (cpy) nodes, then ignore this fork
+                if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) {
+                    GGML_LOG_DEBUG(
+                        "Skipping %s because the number of nodes in the middle is not equal to the total number of "
+                        "branch nodes %d != %d\n",
+                        root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes);
+                    continue;
+                }
+
+                std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context().concurrent_events;
+                GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end());
+                concurrent_events.emplace(root_node, std::move(concurrent_event));
+                GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
+                concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
+
+                // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
+                // example transformation:
+                // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
+                // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn]
+                while (current_node_idx < join_node_idx) {
+                    std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
+
+                    bool has_node = false;
+                    for (std::vector<const ggml_tensor *> branch_node : nodes_per_branch) {
+                        has_node |= branch_node.size() > 0;
+                    }
+
+                    GGML_ASSERT(has_node);
+
+                    if (branch_nodes.empty()) {
+                        current_branch_idx = (current_branch_idx + 1) % n_branches;
+                        continue;
+                    }
+
+                    cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+                    current_node_idx++;
+                    branch_nodes.erase(branch_nodes.begin());
+
+                    // append all empty nodes
+                    while (!branch_nodes.empty() && is_noop(branch_nodes.front())) {
+                        cgraph->nodes[current_node_idx] = const_cast<ggml_tensor *>(branch_nodes.front());
+                        current_node_idx++;
+                        branch_nodes.erase(branch_nodes.begin());
+                    }
+
+                    current_branch_idx = (current_branch_idx + 1) % n_branches;
+                }
+            }
+        }
+    }
+}
+
 static const ggml_backend_i ggml_backend_cuda_interface = {
     /* .get_name                = */ ggml_backend_cuda_get_name,
     /* .free                    = */ ggml_backend_cuda_free,
@@ -3673,7 +3974,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
     /* .graph_compute           = */ ggml_backend_cuda_graph_compute,
     /* .event_record            = */ ggml_backend_cuda_event_record,
     /* .event_wait              = */ ggml_backend_cuda_event_wait,
-    /* .graph_optimize          = */ NULL,
+    /* .graph_optimize          = */ ggml_backend_cuda_graph_optimize,
 };
 
 static ggml_guid_t ggml_backend_cuda_guid() {

+ 1 - 1
ggml/src/ggml-cuda/vendors/hip.h

@@ -105,7 +105,7 @@
 #define cudaStreamNonBlocking hipStreamNonBlocking
 #define cudaStreamPerThread hipStreamPerThread
 #define cudaStreamSynchronize hipStreamSynchronize
-#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
+#define cudaStreamWaitEvent hipStreamWaitEvent
 #define cudaGraphExec_t hipGraphExec_t
 #define cudaGraphNode_t hipGraphNode_t
 #define cudaKernelNodeParams hipKernelNodeParams