diff options
Diffstat (limited to 'ggml-cuda/common.cuh')
-rw-r--r-- | ggml-cuda/common.cuh | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh index b2627b7b..a4197f11 100644 --- a/ggml-cuda/common.cuh +++ b/ggml-cuda/common.cuh @@ -19,6 +19,7 @@ #include <cassert> #include <cfloat> #include <string> +#include <vector> #if defined(GGML_USE_HIPBLAS) #include <hip/hip_runtime.h> @@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu { cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs }; + +#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS) +#define USE_CUDA_GRAPH +#endif + +struct ggml_graph_node_properties { + void * node_address; + ggml_op node_op; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; +}; + +struct ggml_cuda_graph { +#ifdef USE_CUDA_GRAPH + ~ggml_cuda_graph() { + if (instance != nullptr) { + CUDA_CHECK(cudaGraphExecDestroy(instance)); + } + if (graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph)); + } + } + cudaGraph_t graph = nullptr; + cudaGraphExec_t instance = nullptr; + size_t num_nodes = 0; + std::vector<cudaGraphNode_t> nodes; + std::vector<cudaKernelNodeParams> params; + bool disable_due_to_gpu_arch = false; + bool disable_due_to_too_many_updates = false; + bool disable_due_to_failed_graph_capture = false; + int number_consecutive_updates = 0; + std::vector<ggml_graph_node_properties> ggml_graph_properties; + std::vector<char **> updated_kernel_arg; +#endif +}; + struct ggml_backend_cuda_context { int device; std::string name; @@ -534,6 +572,8 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; + std::unique_ptr<ggml_cuda_graph> cuda_graph; + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { |