|
|
@@ -3696,6 +3696,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
|
|
|
}
|
|
|
|
|
|
static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ctx) {
|
|
|
+
|
|
|
#ifdef USE_CUDA_GRAPH
|
|
|
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
|
|
|
|
|
|
@@ -3736,17 +3737,15 @@ static bool ggml_cuda_set_cuda_graph_enabled(ggml_backend_cuda_context * cuda_ct
|
|
|
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
|
|
|
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
|
|
|
|
|
|
+ ggml_cuda_set_device(cuda_ctx->device);
|
|
|
+
|
|
|
bool use_cuda_graph = false;
|
|
|
bool cuda_graph_update_required = false;
|
|
|
|
|
|
// graph_optimize calls set_cuda_graph_enabled, in-case it not called (i.e. graph_compute is directly called)
|
|
|
// we call it here instead.
|
|
|
#ifdef USE_CUDA_GRAPH
|
|
|
- if (!cuda_ctx->cuda_graph) {
|
|
|
- use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
|
|
|
- } else {
|
|
|
- use_cuda_graph = cuda_ctx->cuda_graph && cuda_ctx->cuda_graph->cuda_graphs_enabled;
|
|
|
- }
|
|
|
+ use_cuda_graph = ggml_cuda_set_cuda_graph_enabled(cuda_ctx);
|
|
|
|
|
|
if (use_cuda_graph) {
|
|
|
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
|
|
|
@@ -3762,6 +3761,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
|
|
|
|
|
|
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
|
|
|
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
|
|
|
+ cuda_ctx->cuda_graph->cuda_graphs_enabled = false;
|
|
|
#ifndef NDEBUG
|
|
|
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
|
|
|
#endif
|