summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authoragray3 <agray3@users.noreply.github.com>2024-06-04 21:06:49 +0100
committerGitHub <noreply@github.com>2024-06-04 22:06:49 +0200
commitb90dc566c1c615289b05b50d61680f23744a21e7 (patch)
tree99efb9d3f2075147d9f1855914fe82442d11f88e /ggml-cuda.cu
parent1442677f92e45a475be7b4d056e3633d1d6f813b (diff)
Allow number of nodes in CUDA graph to change (#7738)
Previously the code would have failed to cope in the case that the number of nodes changes in an existing CUDA graph. This fixes the issue by removing an unnecessary conditional.
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu6
1 files changed, 2 insertions, 4 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index daaa0cd6..c81c6a0d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2702,10 +2702,8 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
if (cuda_graph_update_required) {
// Extract nodes from graph
- if (cuda_ctx->cuda_graph->num_nodes == 0) {
- // First call with null argument gets number of nodes in graph
- CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
- }
+ // First call with null argument gets number of nodes in graph
+ CUDA_CHECK(cudaGraphGetNodes(cuda_ctx->cuda_graph->graph, nullptr, &cuda_ctx->cuda_graph->num_nodes));
// Subsequent call with non-null argument gets nodes
cuda_ctx->cuda_graph->nodes.resize(cuda_ctx->cuda_graph->num_nodes);
cuda_ctx->cuda_graph->params.resize(cuda_ctx->cuda_graph->num_nodes);