summaryrefslogtreecommitdiff
path: root/ggml-cuda/common.cuh
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda/common.cuh')
-rw-r--r--ggml-cuda/common.cuh40
1 files changed, 40 insertions, 0 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index b2627b7b..a4197f11 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -19,6 +19,7 @@
#include <cassert>
#include <cfloat>
#include <string>
+#include <vector>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) // needs both CUDART_VERSION >= 12000 (CUDA 12.0 per the 1000*major + 10*minor encoding) and the GGML_CUDA_USE_GRAPHS build flag
+#define USE_CUDA_GRAPH // single switch tested by the #ifdef inside ggml_cuda_graph; when absent that struct has no members
+#endif
+
+struct ggml_graph_node_properties { // plain record of one graph node: an address, an op, ne/nb arrays and source addresses
+ void * node_address; // NOTE(review): presumably the node's data pointer rather than the tensor struct itself - confirm where this is filled in
+ ggml_op node_op; // operation associated with the node
+ int64_t ne[GGML_MAX_DIMS]; // presumably a copy of the tensor's ne[] (elements per dimension) - TODO confirm
+ size_t nb[GGML_MAX_DIMS]; // presumably a copy of the tensor's nb[] (byte strides per dimension) - TODO confirm
+ void * src_address[GGML_MAX_SRC]; // one pointer slot per possible source, GGML_MAX_SRC slots in total
+};
+
+struct ggml_cuda_graph { // holds a CUDA graph, its executable instance and related bookkeeping; NOTE(review): the destructor frees raw handles but copy construction/assignment are not deleted, so a copy would destroy the handles twice - confirm instances are only ever held through a pointer
+#ifdef USE_CUDA_GRAPH // everything below exists only when USE_CUDA_GRAPH is defined; otherwise the struct is empty
+ ~ggml_cuda_graph() { // releases whichever of the two handles is non-null
+ if (instance != nullptr) { // skip when no executable graph was ever stored
+ CUDA_CHECK(cudaGraphExecDestroy(instance)); // NOTE(review): CUDA_CHECK is defined outside this view; if it aborts or throws on error, that happens inside a destructor - confirm this is intended
+ }
+ if (graph != nullptr) { // skip when no graph was ever stored
+ CUDA_CHECK(cudaGraphDestroy(graph)); // the graph is destroyed after its executable instance
+ }
+ }
+ cudaGraph_t graph = nullptr; // graph handle; nullptr until assigned
+ cudaGraphExec_t instance = nullptr; // executable-graph handle; nullptr until assigned
+ size_t num_nodes = 0; // presumably the node count of `graph` - TODO confirm against the code that sets it
+ std::vector<cudaGraphNode_t> nodes; // graph node handles; presumably num_nodes entries - TODO confirm
+ std::vector<cudaKernelNodeParams> params; // kernel node parameters; presumably parallel to `nodes` - TODO confirm
+ bool disable_due_to_gpu_arch = false; // flag, initially false; by its name, set when the GPU architecture rules out graph use - verify against the setter
+ bool disable_due_to_too_many_updates = false; // flag, initially false; by its name, set after too many graph updates - verify against the setter
+ bool disable_due_to_failed_graph_capture = false; // flag, initially false; by its name, set when graph capture fails - verify against the setter
+ int number_consecutive_updates = 0; // counter starting at 0; presumably feeds disable_due_to_too_many_updates - TODO confirm
+ std::vector<ggml_graph_node_properties> ggml_graph_properties; // per-node property records (see ggml_graph_node_properties)
+ std::vector<char **> updated_kernel_arg; // char** entries; presumably pointers to kernel arguments that need updating - TODO confirm
+#endif // USE_CUDA_GRAPH
+};
+
struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+ std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {