diff options
author | Kawrakow <48489457+ikawrakow@users.noreply.github.com> | 2024-08-12 15:14:32 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-08-12 15:14:32 +0200 |
commit | 8f43e551038af2547b5c01d0e9edd641c0e4bd29 (patch) | |
tree | 07a4373620a9381d0b5c7189a475990a6feb48a5 /ggml/src/ggml-cuda.cu | |
parent | f5d1af61d79fb53ccfbac2e665e43208c07b083d (diff) |
Merge mainline - Aug 12 2024 (#17)
* Merge mainline
* Fix after merge
* Remove CI check
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/src/ggml-cuda.cu')
-rw-r--r-- | ggml/src/ggml-cuda.cu | 91 |
1 files changed, 56 insertions, 35 deletions
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 7641d5b5..f594cd26 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -98,7 +98,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in GGML_CUDA_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func, file, line); GGML_CUDA_LOG_ERROR(" %s\n", stmt); // abort with GGML_ASSERT to get a stack trace - GGML_ASSERT(!"CUDA error"); + GGML_ABORT("CUDA error"); } // this is faster on Windows @@ -130,7 +130,22 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) } return res; #else + +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) + cudaError_t err; + if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) + { + err = cudaMallocManaged(ptr, size); + } + else + { + err = cudaMalloc(ptr, size); + } + return err; +#else return cudaMalloc(ptr, size); +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) + #endif } @@ -167,7 +182,7 @@ static ggml_cuda_device_info ggml_cuda_init() { for (int id = 0; id < info.device_count; ++id) { int device_vmm = 0; -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) CUdevice device; CU_CHECK(cuDeviceGet(&device, id)); CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device)); @@ -179,7 +194,7 @@ static ggml_cuda_device_info ggml_cuda_init() { alloc_prop.location.id = id; CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); } -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) info.devices[id].vmm = !!device_vmm; cudaDeviceProp prop; @@ -315,7 +330,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { }; // pool with virtual memory -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) struct ggml_cuda_pool_vmm : public ggml_cuda_pool { static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB @@ -409,14 +424,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool { GGML_ASSERT(ptr == (void *) (pool_addr + pool_used)); } }; -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) { -#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) if (ggml_cuda_info().devices[device].vmm) { return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device)); } -#endif +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA) return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device)); } @@ -1341,7 +1356,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { -#if !defined(GGML_USE_HIPBLAS) +#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices cudaMemcpy3DPeerParms p = {}; p.dstDevice = dstDevice; @@ -1355,7 +1370,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync( GGML_UNUSED(dstDevice); GGML_UNUSED(srcDevice); return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); -#endif // !defined(GGML_USE_HIPBLAS) +#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) } static void ggml_cuda_op_mul_mat( @@ -1486,7 +1501,7 @@ static void ggml_cuda_op_mul_mat( } // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared: - if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { + if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) { const int64_t nbytes_data = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00); const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING); CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream)); @@ -1596,7 +1611,7 @@ static void ggml_cuda_op_mul_mat( CUDA_CHECK(ggml_cuda_cpy_tensor_2d( src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream)); } else { - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } if (quantize_src1 && !src1_is_contiguous) { @@ -1828,6 +1843,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co } } #else +#ifdef GGML_USE_MUSA + GGML_ASSERT(false); +#else // !GGML_USE_MUSA if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) { // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx @@ -1870,6 +1888,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co cu_compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)); } +#endif // GGML_USE_MUSA #endif if (dst->op_params[0] == GGML_PREC_DEFAULT) { @@ -1881,10 +1900,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buffer_is_cuda_split(src0->buffer); - bool use_dequantize_mul_mat_vec = (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) + bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[0] >= GGML_CUDA_DMMV_X*2 - && src1->ne[1] == 1; + && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; @@ -2340,33 +2358,35 @@ GGML_CALL static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, } GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { - GGML_ASSERT(ggml_backend_is_cuda(backend_src) || ggml_backend_is_cuda(backend_dst)); - ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (!ggml_backend_buffer_is_cuda(src->buffer)) { + if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } - // device -> device + // device -> device copy ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; - if (backend_src != backend_dst) { - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - GGML_ASSERT(cuda_ctx_src->device == buf_ctx_src->device); - GGML_ASSERT(cuda_ctx_dst->device == buf_ctx_dst->device); + if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { +#ifndef NDEBUG + GGML_CUDA_LOG_WARN("%s: backend and buffer devices do not match\n", __func__); +#endif + return false; + } + if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { - CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream())); + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); } else { #ifdef GGML_CUDA_NO_PEER_COPY return false; @@ -2375,7 +2395,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_ #endif } - // record event on src stream + // record event on src stream after the copy if (!cuda_ctx_src->copy_event) { ggml_cuda_set_device(cuda_ctx_src->device); CUDA_CHECK(cudaEventCreateWithFlags(&cuda_ctx_src->copy_event, cudaEventDisableTiming)); @@ -2387,7 +2407,7 @@ GGML_CALL static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_ CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx_dst->stream(), cuda_ctx_src->copy_event, 0)); } else { // src and dst are on the same backend - CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_dst->stream())); + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); } return true; } @@ -2724,11 +2744,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_MUL_MAT_ID: { struct ggml_tensor * a = op->src[0]; - if (op->op == GGML_OP_MUL_MAT) { - struct ggml_tensor * b = op->src[1]; - if (a->ne[3] != b->ne[3]) { - return false; - } + struct ggml_tensor * b = op->src[1]; + if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) { + return false; + } + if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) { + return false; } switch (a->type) { case GGML_TYPE_F32: @@ -2867,7 +2888,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return true; case GGML_OP_FLASH_ATTN_EXT: #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) - return op->src[0]->ne[0] == 64 || op->src[0]->ne[0] == 128; + return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128; #else if (op->src[0]->ne[0] == 128) { return true; @@ -2953,7 +2974,7 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev CUDA_CHECK(cudaLaunchHostFunc(cuda_ctx->stream(), wait_fn, event)); #endif - GGML_ASSERT(false); + GGML_ABORT("fatal error"); } } @@ -3035,7 +3056,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size return false; } -#if CUDART_VERSION >= 11100 +#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA) cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly); if (err != cudaSuccess) { // clear the error |