summaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu20
1 files changed, 17 insertions, 3 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 2a90ee55..d0a754ee 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
return id;
}
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+ ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+ auto res = hipMallocManaged(ptr, size);
+ if (res == hipSuccess) {
+ // if error we "need" to know why...
+ CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+ }
+ return res;
+#else
+ return cudaMalloc(ptr, size);
+#endif
+}
+
static ggml_cuda_device_info ggml_cuda_init() {
#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
ggml_cuda_set_device(device);
- CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+ CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
*actual_size = look_ahead_size;
pool_size += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
void * dev_ptr;
- cudaError_t err = cudaMalloc(&dev_ptr, size);
+ cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
if (err != cudaSuccess) {
// clear the error
cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
// currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
ggml_cuda_set_device(id);
char * buf;
- CUDA_CHECK(cudaMalloc(&buf, size));
+ CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
// set padding to 0 to avoid possible NaN values
if (size > original_size) {