Diffstat (limited to 'ggml-cuda')
 ggml-cuda/clamp.cu   |  1 -
 ggml-cuda/common.cuh | 40 ++++++++++++++++++++++++++++++++++++++++
 ggml-cuda/convert.cu |  4 +---
 ggml-cuda/cpy.cu     | 29 +++++++++++++++++++++++++++++
 ggml-cuda/cpy.cuh    |  2 ++
 ggml-cuda/mmq.cu     | 30 ++++++++++--------------------
 ggml-cuda/mmvq.cu    |  6 ++----
 ggml-cuda/scale.cu   |  1 -
 8 files changed, 84 insertions(+), 29 deletions(-)
diff --git a/ggml-cuda/clamp.cu b/ggml-cuda/clamp.cu
index 379ded04..8009a3e3 100644
--- a/ggml-cuda/clamp.cu
+++ b/ggml-cuda/clamp.cu
@@ -31,5 +31,4 @@ void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
- CUDA_CHECK(cudaGetLastError());
}
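
The dropped CUDA_CHECK(cudaGetLastError()) after the kernel launch here (the same check is removed from ggml-cuda/scale.cu at the end of this diff) is consistent with the CUDA-graph support added to common.cuh below: while a stream is under capture, kernel launches are recorded into a graph rather than executed, so a synchronous per-launch error check at this point cannot observe a kernel failure anyway. A minimal sketch of that capture behavior; scale_inplace is a hypothetical stand-in kernel, not code from this diff:

    #include <cuda_runtime.h>

    // Hypothetical stand-in kernel, not part of this diff.
    static __global__ void scale_inplace(float * x, float s, int n) {
        int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) {
            x[i] *= s;
        }
    }

    int main() {
        float * d_x;
        cudaMalloc(&d_x, 256*sizeof(float));

        cudaStream_t stream;
        cudaStreamCreate(&stream);

        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        // Recorded into the graph, not executed: cudaGetLastError() here
        // could not report an execution failure of this kernel.
        scale_inplace<<<1, 256, 0, stream>>>(d_x, 2.0f, 256);
        cudaGraph_t graph;
        cudaStreamEndCapture(stream, &graph);

        cudaGraphDestroy(graph);
        cudaStreamDestroy(stream);
        cudaFree(d_x);
        return 0;
    }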
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index b2627b7b..a4197f11 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -19,6 +19,7 @@
#include <cassert>
#include <cfloat>
#include <string>
+#include <vector>

#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
@@ -526,6 +527,43 @@ struct ggml_tensor_extra_gpu {
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};
+
+#if (CUDART_VERSION >= 12000) && defined(GGML_CUDA_USE_GRAPHS)
+#define USE_CUDA_GRAPH
+#endif
+
+struct ggml_graph_node_properties {
+ void * node_address;
+ ggml_op node_op;
+ int64_t ne[GGML_MAX_DIMS];
+ size_t nb[GGML_MAX_DIMS];
+ void * src_address[GGML_MAX_SRC];
+};
+
+struct ggml_cuda_graph {
+#ifdef USE_CUDA_GRAPH
+ ~ggml_cuda_graph() {
+ if (instance != nullptr) {
+ CUDA_CHECK(cudaGraphExecDestroy(instance));
+ }
+ if (graph != nullptr) {
+ CUDA_CHECK(cudaGraphDestroy(graph));
+ }
+ }
+ cudaGraph_t graph = nullptr;
+ cudaGraphExec_t instance = nullptr;
+ size_t num_nodes = 0;
+ std::vector<cudaGraphNode_t> nodes;
+ std::vector<cudaKernelNodeParams> params;
+ bool disable_due_to_gpu_arch = false;
+ bool disable_due_to_too_many_updates = false;
+ bool disable_due_to_failed_graph_capture = false;
+ int number_consecutive_updates = 0;
+ std::vector<ggml_graph_node_properties> ggml_graph_properties;
+ std::vector<char **> updated_kernel_arg;
+#endif
+};
+
struct ggml_backend_cuda_context {
int device;
std::string name;
@@ -534,6 +572,8 @@ struct ggml_backend_cuda_context {
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+ std::unique_ptr<ggml_cuda_graph> cuda_graph;
+
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
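
Taken together, the two structs added above define the state for a capture/instantiate/update/launch lifecycle: ggml_graph_node_properties snapshots each node's address, op, dimensions, strides, and source addresses so a cheap comparison can tell whether a previously captured graph still matches the current ggml graph, while ggml_cuda_graph owns the captured cudaGraph_t, its executable cudaGraphExec_t, and the bookkeeping for falling back to normal stream execution when graphs stop paying off. A hedged sketch of that lifecycle, written against the CUDA 12 runtime API (matching the CUDART_VERSION >= 12000 guard); eval_with_graph, topology_unchanged, and record_ops are illustrative names, not code from this diff:

    #include <cuda_runtime.h>

    // Illustrative only. topology_unchanged stands for a node-by-node check
    // against ggml_graph_properties (see node_matches below); record_ops
    // stands for replaying the ggml ops into the stream.
    static void eval_with_graph(ggml_cuda_graph & g, cudaStream_t stream,
                                bool topology_unchanged, void (*record_ops)(cudaStream_t)) {
        if (g.instance != nullptr && topology_unchanged) {
            // Fast path: relaunch the previously captured work unchanged.
            g.number_consecutive_updates = 0;
            CUDA_CHECK(cudaGraphLaunch(g.instance, stream));
            return;
        }

        // Capture one evaluation's kernel launches into a fresh graph.
        cudaGraph_t new_graph = nullptr;
        CUDA_CHECK(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
        record_ops(stream);
        CUDA_CHECK(cudaStreamEndCapture(stream, &new_graph));

        if (g.instance == nullptr) {
            // First evaluation: build the executable graph (CUDA 12 signature).
            CUDA_CHECK(cudaGraphInstantiate(&g.instance, new_graph, 0));
        } else {
            // Try a cheap in-place update of the executable; re-instantiate
            // from scratch only if the update is rejected.
            g.number_consecutive_updates++; // feeds disable_due_to_too_many_updates
            cudaGraphExecUpdateResultInfo info;
            if (cudaGraphExecUpdate(g.instance, new_graph, &info) != cudaSuccess) {
                (void) cudaGetLastError(); // clear the sticky error state
                CUDA_CHECK(cudaGraphExecDestroy(g.instance));
                CUDA_CHECK(cudaGraphInstantiate(&g.instance, new_graph, 0));
            }
        }

        if (g.graph != nullptr) {
            CUDA_CHECK(cudaGraphDestroy(g.graph));
        }
        g.graph = new_graph;

        CUDA_CHECK(cudaGraphLaunch(g.instance, stream));
    }

The topology check that gates the fast path could compare a tensor against its snapshot using only the fields of ggml_graph_node_properties (again illustrative):

    static bool node_matches(const ggml_tensor * t, const ggml_graph_node_properties & p) {
        if (t->data != p.node_address || t->op != p.node_op) {
            return false;
        }
        for (int i = 0; i < GGML_MAX_DIMS; i++) {
            if (t->ne[i] != p.ne[i] || t->nb[i] != p.nb[i]) {
                return false;
            }
        }
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if ((t->src[i] ? t->src[i]->data : nullptr) != p.src_address[i]) {
                return false;
            }
        }
        return true;
    }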
diff --git a/ggml-cuda/convert.cu b/ggml-cuda/convert.cu
index 75e50c98..830e2d75 100644
--- a/ggml-cuda/convert.cu
+++ b/ggml-cuda/convert.cu
@@ -727,7 +727,6 @@ static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__
}
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
- int id;
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
@@ -738,8 +737,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
- CUDA_CHECK(cudaGetDevice(&id));
- if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
+ if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
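
This hunk, and the ten hunks in ggml-cuda/mmq.cu plus two in ggml-cuda/mmvq.cu below, all collapse the same two-line cudaGetDevice idiom into a single ggml_cuda_get_device() call. The helper itself is defined outside this diff; presumably it is a thin, error-checked wrapper along these lines, which keeps every call site to one expression:

    // Presumed shape of the helper (not shown in this diff).
    int ggml_cuda_get_device() {
        int id;
        CUDA_CHECK(cudaGetDevice(&id));
        return id;
    }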
diff --git a/ggml-cuda/cpy.cu b/ggml-cuda/cpy.cu
index 16d9c8ff..12d741f0 100644
--- a/ggml-cuda/cpy.cu
+++ b/ggml-cuda/cpy.cu
@@ -459,3 +459,32 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
ggml_cuda_cpy(ctx, src0, dst);
}
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f32>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+ return (void*) cpy_f32_f16<cpy_1_f32_f16>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
+ return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
+ } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
+ return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+ return (void*) cpy_f32_f16<cpy_1_f16_f16>;
+ } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+ return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+ } else {
+ fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+ ggml_type_name(src0->type), ggml_type_name(src1->type));
+ GGML_ASSERT(false);
+ }
+}
+
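
ggml_cuda_cpy_fn returns the kernel that ggml_cuda_cpy would launch for a given type pair, without launching it. Combined with the graph state added to common.cuh above, one plausible use is locating the captured copy-kernel nodes whose pointer arguments (for example the KV-cache destination) change between evaluations, and patching them in the executable graph instead of re-capturing. A hedged sketch against the CUDA 12 graph API; update_cpy_kernel_args and new_kernel_args are illustrative names:

    #include <cuda_runtime.h>
    #include <vector>

    // Illustrative sketch: find the copy-kernel nodes in a captured graph and
    // swap in a fresh argument block before the next launch. The ggml types
    // come from this diff's context; new_kernel_args is a placeholder.
    static void update_cpy_kernel_args(ggml_cuda_graph & g,
            const ggml_tensor * src0, ggml_tensor * src1, char ** new_kernel_args) {
        void * cpy_fn = ggml_cuda_cpy_fn(src0, src1);

        size_t num_nodes = 0;
        CUDA_CHECK(cudaGraphGetNodes(g.graph, nullptr, &num_nodes));
        std::vector<cudaGraphNode_t> nodes(num_nodes);
        CUDA_CHECK(cudaGraphGetNodes(g.graph, nodes.data(), &num_nodes));

        for (size_t i = 0; i < num_nodes; i++) {
            cudaGraphNodeType type;
            CUDA_CHECK(cudaGraphNodeGetType(nodes[i], &type));
            if (type != cudaGraphNodeTypeKernel) {
                continue;
            }
            cudaKernelNodeParams params = {};
            CUDA_CHECK(cudaGraphKernelNodeGetParams(nodes[i], &params));
            if (params.func == cpy_fn) {
                // Patch the node in the executable graph; far cheaper than
                // destroying and re-instantiating the whole instance.
                params.kernelParams = (void **) new_kernel_args;
                CUDA_CHECK(cudaGraphExecKernelNodeSetParams(g.instance, nodes[i], &params));
            }
        }
    }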
diff --git a/ggml-cuda/cpy.cuh b/ggml-cuda/cpy.cuh
index f0b2c453..79616742 100644
--- a/ggml-cuda/cpy.cuh
+++ b/ggml-cuda/cpy.cuh
@@ -5,3 +5,5 @@
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu
index 60d6616a..7948f1b1 100644
--- a/ggml-cuda/mmq.cu
+++ b/ggml-cuda/mmq.cu
@@ -1735,8 +1735,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -1780,8 +1779,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -1825,8 +1823,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -1870,8 +1867,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -1915,8 +1911,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -1960,8 +1955,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -2007,8 +2001,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
#if QK_K == 256
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -2053,8 +2046,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -2098,8 +2090,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
@@ -2143,8 +2134,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
const int compute_capability = ggml_cuda_info().devices[id].cc;
int mmq_x, mmq_y, nwarps;
diff --git a/ggml-cuda/mmvq.cu b/ggml-cuda/mmvq.cu
index 39655900..65cc1bca 100644
--- a/ggml-cuda/mmvq.cu
+++ b/ggml-cuda/mmvq.cu
@@ -89,8 +89,7 @@ static void mul_mat_vec_q_cuda(
GGML_ASSERT(ncols_x % qk == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
@@ -328,8 +327,7 @@ void ggml_cuda_op_mul_mat_vec_q(
const int64_t ne0 = dst->ne[0];
- int id;
- CUDA_CHECK(cudaGetDevice(&id));
+ int id = ggml_cuda_get_device();
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
diff --git a/ggml-cuda/scale.cu b/ggml-cuda/scale.cu
index 6e3617d1..1405e066 100644
--- a/ggml-cuda/scale.cu
+++ b/ggml-cuda/scale.cu
@@ -28,5 +28,4 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
memcpy(&scale, dst->op_params, sizeof(float));
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
- CUDA_CHECK(cudaGetLastError());
}