|
@@ -2610,13 +2610,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
|
|
|
|
|
|
|
|
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
|
|
|
|
|
|
|
|
|
|
+#if CUDART_VERSION >= 12000
|
|
|
cudaGraphExecUpdateResultInfo result_info;
|
|
cudaGraphExecUpdateResultInfo result_info;
|
|
|
-#ifdef __HIP_PLATFORM_AMD__
|
|
|
|
|
- hipGraphNode_t errorNode;
|
|
|
|
|
- hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
|
|
|
|
-#else
|
|
|
|
|
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
|
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
|
|
|
-#endif
|
|
|
|
|
|
|
+#else
|
|
|
|
|
+ cudaGraphNode_t errorNode;
|
|
|
|
|
+ cudaGraphExecUpdateResult result_info;
|
|
|
|
|
+ cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
|
|
|
|
|
+#endif // CUDART_VERSION >= 12000
|
|
|
|
|
+
|
|
|
if (stat == cudaErrorGraphExecUpdateFailure) {
|
|
if (stat == cudaErrorGraphExecUpdateFailure) {
|
|
|
#ifndef NDEBUG
|
|
#ifndef NDEBUG
|
|
|
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|
|
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
|