Browse Source

CUDA: disable cuda graph when using n-cpu-moe (#18593)

* CUDA: disable cuda graph when using n-cpu-moe

* call ggml_cuda_set_device
Aman Gupta 3 weeks ago
parent
commit
908a9e5a1e
1 changed files with 5 additions and 5 deletions
  1. 5 5
      ggml/src/ggml-cuda/ggml-cuda.cu

+ 5 - 5
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 }
 
 static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
+
 #ifdef USE_CUDA_GRAPH
     static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
 
@@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
 static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
 
+    ggml_cuda_set_device(cuda_ctx->device);
+
     bool use_cuda_graph             = false;
     bool cuda_graph_update_required = false;
 
     // graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
     // we call it here instead.
 #ifdef USE_CUDA_GRAPH
-    if (!cuda_ctx->cuda_graph) {
-        use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
-    } else {
-        use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
-    }
+    use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
 
     if (use_cuda_graph) {
         cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
@@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
 
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
+            cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif