diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 44 |
1 files changed, 24 insertions, 20 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index c6bc3f64..64d3b674 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -543,6 +543,10 @@ GGML_CALL static const char * ggml_backend_cuda_buffer_type_name(ggml_backend_bu return ctx->name.c_str(); } +static bool ggml_backend_buft_is_cuda(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_buffer_type_name; +} + GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; @@ -585,24 +589,12 @@ GGML_CALL static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backen GGML_UNUSED(buft); } -GGML_CALL static bool ggml_backend_cuda_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - if (!ggml_backend_is_cuda(backend)) { - return false; - } - - ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - return buft_ctx->device == cuda_ctx->device; -} - static ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { /* .get_name = */ ggml_backend_cuda_buffer_type_name, /* .alloc_buffer = */ ggml_backend_cuda_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cuda_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_buffer_type_get_alloc_size, - /* .supports_backend = */ ggml_backend_cuda_buffer_type_supports_backend, /* .is_host = */ NULL, }; @@ -863,6 +855,10 @@ GGML_CALL static const char * ggml_backend_cuda_split_buffer_type_name(ggml_back GGML_UNUSED(buft); } +static bool ggml_backend_buft_is_cuda_split(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cuda_split_buffer_type_name; +} + GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point // instead, we allocate them for each tensor separately in init_tensor @@ -906,12 +902,6 @@ GGML_CALL static size_t ggml_backend_cuda_split_buffer_type_get_alloc_size(ggml_ return total_size; } -GGML_CALL static bool ggml_backend_cuda_split_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return ggml_backend_is_cuda(backend); - - GGML_UNUSED(buft); -} - GGML_CALL static bool ggml_backend_cuda_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) { return false; @@ -924,7 +914,6 @@ static ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_interface /* .get_alignment = */ ggml_backend_cuda_split_buffer_type_get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cuda_split_buffer_type_get_alloc_size, - /* .supports_backend = */ ggml_backend_cuda_split_buffer_type_supports_backend, /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; @@ -1024,7 +1013,6 @@ GGML_CALL ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type() { /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment, /* .get_max_size = */ NULL, // defaults to SIZE_MAX /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, - /* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend, /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, }, /* .context = */ nullptr, @@ -2879,6 +2867,20 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons GGML_UNUSED(backend); } +GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { + if (ggml_backend_buft_is_cuda_split(buft)) { + return true; + } + + if (ggml_backend_buft_is_cuda(buft)) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *)buft->context; + return buft_ctx->device == cuda_ctx->device; + } + + return false; +} + GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) { const int min_batch_size = 32; @@ -2951,9 +2953,11 @@ static ggml_backend_i ggml_backend_cuda_interface = { /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .supports_op = */ ggml_backend_cuda_supports_op, + /* .supports_buft = */ ggml_backend_cuda_supports_buft, /* .offload_op = */ ggml_backend_cuda_offload_op, /* .event_new = */ ggml_backend_cuda_event_new, /* .event_free = */ ggml_backend_cuda_event_free, |