Răsfoiți Sursa

cuda : enable CUDA Graph on CUDA Toolkit < 12.x (#12394)

* Enable CUDA Graph on CTK < 12.x

`cudaGraphExecUpdate` API was changed on 12.x. For this reason CUDA graph support was disabled on older CUDA toolkit. This change enables CUDA support in CTK version < 12.x by using older API if CTK < 12.x.

* Fix compilation errors with MUSA

* Disable CUDA Graph for MUSA
Gaurav Garg 10 luni în urmă
părinte
comite
b1b132efcb

+ 1 - 1
ggml/src/ggml-cuda/common.cuh

@@ -678,7 +678,7 @@ struct ggml_tensor_extra_gpu {
 };
 
 
-#if ((CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)) || defined(GGML_HIP_GRAPHS)
+#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS))
 #define USE_CUDA_GRAPH
 #endif
 

+ 7 - 5
ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
 
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 
+#if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
-#ifdef __HIP_PLATFORM_AMD__
-    hipGraphNode_t errorNode;
-    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
-#else
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
-#endif
+#else
+    cudaGraphNode_t errorNode;
+    cudaGraphExecUpdateResult result_info;
+    cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#endif // CUDART_VERSION >= 12000
+
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);

+ 1 - 1
ggml/src/ggml-cuda/vendors/hip.h

@@ -112,7 +112,7 @@
 #define cudaGraphExecDestroy hipGraphExecDestroy
 #define cudaGraphLaunch hipGraphLaunch
 #define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
-#define cudaGraphExecUpdateResultInfo hipGraphExecUpdateResult
+#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
 #define cudaGraphNodeType hipGraphNodeType
 #define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
 #define cudaGraphInstantiate hipGraphInstantiate

+ 2 - 1
ggml/src/ggml-cuda/vendors/musa.h

@@ -119,7 +119,7 @@
 #define cudaGraphExecDestroy musaGraphExecDestroy
 #define cudaGraphExec_t musaGraphExec_t
 #define cudaGraphExecUpdate musaGraphExecUpdate
-#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
+#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
 #define cudaGraphGetNodes musaGraphGetNodes
 #define cudaGraphInstantiate musaGraphInstantiate
 #define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
@@ -132,6 +132,7 @@
 #define cudaGraph_t musaGraph_t
 #define cudaKernelNodeParams musaKernelNodeParams
 #define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
+#define cudaStreamBeginCapture musaStreamBeginCapture
 #define cudaStreamEndCapture musaStreamEndCapture
 
 typedef mt_bfloat16 nv_bfloat16;

+ 0 - 4
ggml/src/ggml-musa/CMakeLists.txt

@@ -67,10 +67,6 @@ if (MUSAToolkit_FOUND)
     add_compile_definitions(GGML_USE_MUSA)
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
-    if (GGML_CUDA_GRAPHS)
-        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
-    endif()
-
     if (GGML_CUDA_FORCE_MMQ)
         add_compile_definitions(GGML_CUDA_FORCE_MMQ)
     endif()