Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu | 44
1 file changed, 24 insertions(+), 20 deletions(-)
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index c6bc3f64..64d3b674 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -543,6 +543,10 @@ GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_bu
     return ctx->name.c_str();
 }
 
+static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_buffer_type_name;
+}
+
 GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
 
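The helper added above identifies a CUDA buffer type by comparing the interface's get_name function pointer: every buffer type created by this backend shares the same interface, so pointer equality is a cheap, string-free type check. A minimal sketch of the same idiom for a hypothetical backend (my_buft_name and my_buft_is_mine are illustrative names, not ggml API):

static const char * my_buft_name(ggml_backend_buffer_type_t buft) {
    return "MyBackend";
    GGML_UNUSED(buft);
}

static bool my_buft_is_mine(ggml_backend_buffer_type_t buft) {
    // all buffer types built by this backend carry my_buft_name in their
    // iface, so comparing the pointer identifies them without strcmp
    return buft->iface.get_name == my_buft_name;
}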
@@ -585,24 +589,12 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen
     GGML_UNUSED(buft);
 }
 
-GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    if (!ggml_backend_is_cuda(backend)) {
-        return false;
-    }
-
-    ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
-    ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
-
-    return buft_ctx->device == cuda_ctx->device;
-}
-
 static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_cuda_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_cuda_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_cuda_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cuda_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend,
     /* .is_host          = */ NULL,
 };
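With supports_backend gone from ggml_backend_buffer_type_i, the buffer type no longer answers compatibility questions; the query flips direction and moves onto the backend, as ggml_backend_cuda_supports_buft later in this diff. A sketch of the two call shapes, the old one from the code removed above, the new one from the callback added below:

// before: the buffer type was asked whether a backend may use it
//   bool ok = buft->iface.supports_backend(buft, backend);
// after:  the backend is asked whether it can use a buffer type
//   bool ok = backend->iface.supports_buft(backend, buft);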
@@ -863,6 +855,10 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_back
     GGML_UNUSED(buft);
 }
 
+static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name;
+}
+
 GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
     // instead, we allocate them for each tensor separately in init_tensor
@@ -906,12 +902,6 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_
     return total_size;
 }
 
-GGML_CALL static bool ggml_backend_cuda_split_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) {
-    return ggml_backend_is_cuda(backend);
-
-    GGML_UNUSED(buft);
-}
-
 GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
     return false;
@@ -924,7 +914,6 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface
     /* .get_alignment    = */ ggml_backend_cuda_split_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cuda_split_buffer_type_get_alloc_size,
-    /* .supports_backend = */ ggml_backend_cuda_split_buffer_type_supports_backend,
     /* .is_host          = */ ggml_backend_cuda_split_buffer_type_is_host,
 };
@@ -1024,7 +1013,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() {
             /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
             /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
             /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
             /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
         /* .context  = */ nullptr,
@@ -2879,6 +2867,20 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
     GGML_UNUSED(backend);
 }
 
+GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cuda_split(buft)) {
+        return true;
+    }
+
+    if (ggml_backend_buft_is_cuda(buft)) {
+        ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+        ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context;
+        return buft_ctx->device == cuda_ctx->device;
+    }
+
+    return false;
+}
+
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
 
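A hedged usage sketch of the new check; ggml_backend_cuda_init and ggml_backend_cuda_buffer_type are existing public entry points of this backend, while the direct iface call stands in for whatever wrapper ggml-backend exposes:

// a backend bound to device 0 accepts buffer types for device 0, rejects
// buffer types for other devices, and accepts any CUDA split buffer type
ggml_backend_t cuda0 = ggml_backend_cuda_init(0);

bool same_dev  = cuda0->iface.supports_buft(cuda0, ggml_backend_cuda_buffer_type(0)); // true
bool other_dev = cuda0->iface.supports_buft(cuda0, ggml_backend_cuda_buffer_type(1)); // false: device mismatch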
@@ -2951,9 +2953,11 @@ static ggml_backend_i ggml_backend_cuda_interface = {
     /* .synchronize         = */ ggml_backend_cuda_synchronize,
     /* .graph_plan_create   = */ NULL,
     /* .graph_plan_free     = */ NULL,
+    /* .graph_plan_update   = */ NULL,
     /* .graph_plan_compute  = */ NULL,
     /* .graph_compute       = */ ggml_backend_cuda_graph_compute,
     /* .supports_op         = */ ggml_backend_cuda_supports_op,
+    /* .supports_buft       = */ ggml_backend_cuda_supports_buft,
     /* .offload_op          = */ ggml_backend_cuda_offload_op,
     /* .event_new           = */ ggml_backend_cuda_event_new,
     /* .event_free          = */ ggml_backend_cuda_event_free,
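At the ggml_backend_i level the new member slots in next to supports_op. Generic code dispatching through the interface might guard the callback as sketched here (the NULL fallback is an assumption for illustration, not necessarily how ggml-backend handles a missing implementation):

static bool backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    // hypothetical fallback: a backend without the callback only takes host buffers
    if (backend->iface.supports_buft == NULL) {
        return ggml_backend_buft_is_host(buft);
    }
    return backend->iface.supports_buft(backend, buft);
}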