Browse Source

ggml-cuda: enable cuda-graphs for `n-cpu-moe` (#18934)

* ggml-cuda: add split-wise cuda graph

* add n-cpu-moe compare_llama_bench.py

* fix hip/musa builds
Aman Gupta 4 days ago
parent
commit
81ab64f3c8

+ 36 - 2
ggml/src/ggml-cuda/common.cuh

@@ -1327,10 +1327,44 @@ struct ggml_backend_cuda_context {
     cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
     cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
-    std::unique_ptr<ggml_cuda_graph> cuda_graph;
-
     int curr_stream_no = 0;
 
+#ifdef USE_CUDA_GRAPH
+    // Map from first_node_ptr to cuda_graph - allows multiple graphs per context
+    // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe)
+    std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs;
+
+    ggml_cuda_graph * cuda_graph(const void * first_node_ptr) {
+        auto it = cuda_graphs.find(first_node_ptr);
+        if (it == cuda_graphs.end()) {
+            cuda_graphs[first_node_ptr] = std::make_unique<ggml_cuda_graph>();
+            return cuda_graphs[first_node_ptr].get();
+        }
+        return it->second.get();
+    }
+
+    // Check if any CUDA graph is enabled for this context (used by kernels that need to know
+    // if graphs are in use without having access to the specific graph key)
+    bool any_cuda_graph_enabled() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->is_enabled()) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Check if any CUDA graph has an instance for this context
+    bool any_cuda_graph_has_instance() const {
+        for (const auto & [key, graph] : cuda_graphs) {
+            if (graph && graph->instance != nullptr) {
+                return true;
+            }
+        }
+        return false;
+    }
+#endif // USE_CUDA_GRAPH
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {

+ 57 - 38
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2969,18 +2969,25 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
     return true;
 }
 
+static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
+    return cgraph->nodes[0];
+}
+
 static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
 
     bool res = false;
 
-    if (cuda_ctx->cuda_graph->instance == nullptr) {
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+
+    if (graph->instance == nullptr) {
         res = true;
     }
 
     // Check if the graph size has changed
-    if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
+    if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
         res = true;
-        cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
+        graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
@@ -2988,37 +2995,38 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool props_match = true;
         if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]);
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
         }
         if (!props_match) {
             res = true;
         }
-        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]);
+        ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
     }
 
     for (int i = 0; i < cgraph->n_leafs; i++) {
-        bool props_match= true;
+        bool props_match = true;
         if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
+            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
         }
         if (!props_match) {
             res = true;
         }
-        ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
+        ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
     }
 
     return res;
 }
 
-static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) {
+static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info);
 #else
     cudaGraphNode_t errorNode;
     cudaGraphExecUpdateResult result_info;
-    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+    cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info);
 #endif // CUDART_VERSION >= 12000
 
     if (stat == cudaErrorGraphExecUpdateFailure) {
@@ -3029,14 +3037,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c
         // The pre-existing graph exec cannot be updated due to violated constraints
         // so instead clear error and re-instantiate
         (void)cudaGetLastError();
-        CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
-        cuda_ctx->cuda_graph->instance = nullptr;
-        CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        CUDA_CHECK(cudaGraphExecDestroy(graph->instance));
+        graph->instance = nullptr;
+        CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
     } else {
         GGML_ASSERT(stat == cudaSuccess);
     }
 }
-#endif
+#endif // USE_CUDA_GRAPH
 
 static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
                                                 const ggml_tensor * view,
@@ -3241,7 +3249,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) {
+static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) {
     bool graph_evaluated_or_captured = false;
 
     // flag used to determine whether it is an integrated_gpu
@@ -3695,13 +3703,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
         }
 
 #ifdef USE_CUDA_GRAPH
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
         if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
-            if (cuda_ctx->cuda_graph->graph != nullptr) {
-                CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
-                cuda_ctx->cuda_graph->graph = nullptr;
+            if (graph->graph != nullptr) {
+                CUDA_CHECK(cudaGraphDestroy(graph->graph));
+                graph->graph = nullptr;
             }
 
-            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
+            CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph));
             graph_evaluated_or_captured = true; // CUDA graph has been captured
 
             std::lock_guard<std::mutex> lock(ggml_cuda_lock);
@@ -3714,40 +3723,39 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
     }
 
     if (use_cuda_graph) {
-        if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
-            CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
+        ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+        if (graph->instance == nullptr) { // Create executable graph from captured graph.
+            CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0));
         }
         if (cuda_graph_update_required) { // Update graph executable
-            ggml_cuda_graph_update_executable(cuda_ctx);
+            ggml_cuda_graph_update_executable(cuda_ctx, graph_key);
         }
         // Launch graph
-        CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
+        CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
 #else
         graph_evaluated_or_captured = true;
 #endif  // USE_CUDA_GRAPH
     }
 }
 
-static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) {
+static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
 
 #ifdef USE_CUDA_GRAPH
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
 
-    if (cuda_ctx->cuda_graph == nullptr) {
-        cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
-    }
-
-    if (cuda_ctx->cuda_graph->graph == nullptr) {
+    if (graph->graph == nullptr) {
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
-            if (!cuda_ctx->cuda_graph->disable_due_to_gpu_arch) {
+            if (!graph->disable_due_to_gpu_arch) {
                 GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
             }
-            cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
+            graph->disable_due_to_gpu_arch = true;
         }
     }
 
-    return cuda_ctx->cuda_graph->is_enabled();
+    return graph->is_enabled();
 #else
     GGML_UNUSED(cuda_ctx);
+    GGML_UNUSED(graph_key);
     return false;
 #endif // USE_CUDA_GRAPH
 }
@@ -3759,15 +3767,19 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
     bool use_cuda_graph             = false;
     bool cuda_graph_update_required = false;
+    const void * graph_key = nullptr;
 
 #ifdef USE_CUDA_GRAPH
-    use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+    graph_key = ggml_cuda_graph_get_key(cgraph);
+
+    use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
 
-    if (cuda_ctx->cuda_graph->is_enabled()) {
+    ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
+    if (graph->is_enabled()) {
         cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph);
         use_cuda_graph             = ggml_cuda_graph_check_compability(cgraph);
 
-        cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
+        graph->record_update(use_cuda_graph, cuda_graph_update_required);
     }
 #endif // USE_CUDA_GRAPH
 
@@ -3781,7 +3793,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
     }
 
-    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required);
+    ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key);
 
     return GGML_STATUS_SUCCESS;
 }
@@ -3814,7 +3826,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
 static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
-    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
+#ifdef USE_CUDA_GRAPH
+    const void * graph_key = ggml_cuda_graph_get_key(cgraph);
+    const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key);
+#else
+    const bool use_cuda_graph = false;
+    GGML_UNUSED(cuda_ctx);
+    GGML_UNUSED(cgraph);
+#endif
 
     static bool enable_graph_optimization = [] {
         const char * env     = getenv("GGML_CUDA_GRAPH_OPT");

+ 9 - 8
ggml/src/ggml-cuda/mean.cu

@@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 #endif // USE_CUDA_GRAPH
     if ((nrows == 1) &&
 #ifdef USE_CUDA_GRAPH
-            // CUDA_GRAPHS_DISABLED
-            ((ncols > 65536) &&
-             ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-              ctx.cuda_graph->is_enabled())) ||
-        // CUDA_GRAPHS ENABLED
-        ((ncols > 32768) &&
-         !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
-            ctx.cuda_graph->is_enabled()))) {
+            // Determine if CUDA graphs are effectively disabled for this context
+            // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled)
+            (((ncols > 65536) &&
+              (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+               ctx.any_cuda_graph_enabled())) ||
+            // CUDA graphs are enabled - use lower threshold
+             ((ncols > 32768) &&
+              !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) ||
+                ctx.any_cuda_graph_enabled())))) {
 #else
         (ncols > 65536)) {
 #endif // USE_CUDA_GRAPH

+ 3 - 3
scripts/compare-llama-bench.py

@@ -29,7 +29,7 @@ LLAMA_BENCH_DB_FIELDS = [
     "cpu_mask",     "cpu_strict",   "poll",           "type_k",     "type_v",       "n_gpu_layers",
     "split_mode",   "main_gpu",     "no_kv_offload",  "flash_attn", "tensor_split", "tensor_buft_overrides",
     "use_mmap",     "embeddings",   "no_op_offload",  "n_prompt",   "n_gen",        "n_depth",
-    "test_time",    "avg_ns",       "stddev_ns",      "avg_ts",     "stddev_ts",
+    "test_time",    "avg_ns",       "stddev_ns",      "avg_ts",     "stddev_ts",    "n_cpu_moe"
 ]
 
 LLAMA_BENCH_DB_TYPES = [
@@ -38,7 +38,7 @@ LLAMA_BENCH_DB_TYPES = [
     "TEXT",    "INTEGER", "INTEGER", "TEXT",    "TEXT",    "INTEGER",
     "TEXT",    "INTEGER", "INTEGER", "INTEGER", "TEXT",    "TEXT",
     "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER",
-    "TEXT",    "INTEGER", "INTEGER", "REAL",    "REAL",
+    "TEXT",    "INTEGER", "INTEGER", "REAL",    "REAL",    "INTEGER",
 ]
 
 # All test-backend-ops SQL fields
@@ -59,7 +59,7 @@ assert len(TEST_BACKEND_OPS_DB_FIELDS) == len(TEST_BACKEND_OPS_DB_TYPES)
 
 # Properties by which to differentiate results per commit for llama-bench:
 LLAMA_BENCH_KEY_PROPERTIES = [
-    "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type",
+    "cpu_info", "gpu_info", "backends", "n_gpu_layers", "n_cpu_moe", "tensor_buft_overrides", "model_filename", "model_type",
     "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v",
     "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth"
 ]