diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 88 |
1 files changed, 57 insertions, 31 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dad8a9e2..af10f21a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { GGML_UNUSED(main_device); } +static cudaError_t ggml_cuda_Memcpy2DPeerAsync( + void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { + +#if !defined(GGML_USE_HIPBLAS) + // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices + cudaMemcpy3DPeerParms p = {}; + p.dstDevice = dstDevice; + p.dstPtr = make_cudaPitchedPtr(dst, dpitch, dpitch, height); + p.srcDevice = srcDevice; + p.srcPtr = make_cudaPitchedPtr(src, spitch, spitch, height); + p.extent = make_cudaExtent(width, height, 1); + return cudaMemcpy3DPeerAsync(&p, stream); +#else + // HIP does not support cudaMemcpy3DPeerAsync or vmm pools + GGML_UNUSED(dstDevice); + GGML_UNUSED(srcDevice); + return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream); +#endif // !defined(GGML_USE_HIPBLAS) +} + static void ggml_cuda_op_mul_mat( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op, - const bool convert_src1_to_q8_1) { + quantize_cuda_t quantize_src1) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat( } struct dev_data { - ggml_cuda_pool_alloc<char> src0_dd_alloc; + int cc; + + ggml_cuda_pool_alloc<char> src0_dd_alloc; ggml_cuda_pool_alloc<float> src1_ddf_alloc; ggml_cuda_pool_alloc<char> src1_ddq_alloc; ggml_cuda_pool_alloc<float> dst_dd_alloc; @@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat( int used_devices = 0; for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { + dev[id].cc = ggml_cuda_info().devices[id].cc; + // by default, use all rows dev[id].row_low = 0; dev[id].row_high = ne01; @@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat( dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1)); } - if (convert_src1_to_q8_1) { - dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs); + if (quantize_src1) { + size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq); + } + dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size); if (src1_on_device && src1_is_contiguous) { - quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream); + quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } } @@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat( const int64_t i03 = i0 / ne12; const int64_t i02 = i0 % ne12; - const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs; + size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs; + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq); + } else { + src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs; + } // for split tensors the data begins at i0 == i0_offset_low char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs; @@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat( // copy src0, src1 to device if necessary if (src1_is_contiguous) { if (id != ctx.device) { - if (convert_src1_to_q8_1) { + if (quantize_src1) { char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset; - CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device, - src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + if (quantize_src1 == quantize_mmq_q8_1_cuda) { + const size_t pitch = ne11*sizeof(block_q8_1_mmq); + const size_t width = src1_ncols*sizeof(block_q8_1_mmq); + const size_t height = src1_padded_col_size/(4*QK8_1); + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream)); + } else { + CUDA_CHECK(cudaMemcpyPeerAsync( + src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream)); + } } else { float * src1_ddf_i_source = (float *) src1->data; src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10; @@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat( GGML_ASSERT(false); } - if (convert_src1_to_q8_1 && !src1_is_contiguous) { - quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream); + if (quantize_src1 && !src1_is_contiguous) { + quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream); CUDA_CHECK(cudaGetLastError()); } @@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat( float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); dhf_dst_i += src1_col_0*ne0 + dev[id].row_low; -#if !defined(GGML_USE_HIPBLAS) - // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices - cudaMemcpy3DPeerParms p = {}; - p.dstDevice = ctx.device; - p.dstPtr = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols); - p.srcDevice = id; - p.srcPtr = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols); - p.extent = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1); - CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream)); -#else - // HIP does not support cudaMemcpy3DPeerAsync or vmm pools - CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float), - dst_dd_i, row_diff*sizeof(float), - row_diff*sizeof(float), src1_ncols, - cudaMemcpyDeviceToDevice, stream)); -#endif + CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync( + dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream)); } else { float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3); GGML_ASSERT(dst->nb[1] == ne0*sizeof(float)); @@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor // KQ + KQV multi-batch ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); } else { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } } |