From 4601a8c3735d8e47c46e0927712d77c4f422be6c Mon Sep 17 00:00:00 2001 From: Kawrakow Date: Fri, 7 Feb 2025 08:33:42 +0200 Subject: cuda: non-contiguous rms norm (#190) Co-authored-by: Iwan Kawrakow --- ggml/src/ggml-cuda/norm.cu | 153 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 141 insertions(+), 12 deletions(-) (limited to 'ggml/src') diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 7e670912..9e4931a3 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -131,6 +131,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } +template +static __global__ void rms_norm_f32_nc( + const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * x[col]; + } +} + template static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa } } +template +static __global__ void fused_rms_norm_f32_nc( + const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps) { + const int nrows = gridDim.x; + const int nchannels = gridDim.y; + + const int row = blockIdx.x; + const int channel = blockIdx.y; + const int sample = blockIdx.z; + const int tid = threadIdx.x; + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if constexpr (block_size > WARP_SIZE) { + static_assert(block_size == 1024, "unexpected block_size"); + __shared__ float s_sum[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[col] = scale * y[col] * x[col]; + } +} + static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void rms_norm_f32_nc_cuda( + const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32_nc<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + rms_norm_f32_nc<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); @@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds } } +static void fused_rms_norm_f32_nc_cuda( + const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + fused_rms_norm_f32_nc<<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 1); + fused_rms_norm_f32_nc<1024><<>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - float eps; memcpy(&eps, dst->op_params, sizeof(float)); - rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + const int64_t ne00 = src0->ne[0]; + if (ggml_is_contiguous(src0)) { + const int64_t nrows = ggml_nrows(src0); + rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); + } else { + auto ts0 = ggml_type_size(src0->type); + GGML_ASSERT(src0->nb[0] == ts0); + auto s01 = src0->nb[1] / ts0; + auto s02 = src0->nb[2] / ts0; + auto s03 = src0->nb[3] / ts0; + rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream); + } } void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * float * dst_d = (float *)dst->data; cudaStream_t stream = ctx.stream(); - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT(src0->ne[0] == src1->ne[0]); GGML_ASSERT(ggml_nrows(src1) == 1); - const int64_t ne00 = src0->ne[0]; - const int64_t nrows = ggml_nrows(src0); - float eps; memcpy(&eps, dst->op_params, sizeof(float)); - fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); + const int64_t ne00 = src0->ne[0]; + + if (ggml_is_contiguous(src0)) { + const int64_t nrows = ggml_nrows(src0); + fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); + } else { + auto ts0 = ggml_type_size(src0->type); + GGML_ASSERT(src0->nb[0] == ts0); + auto s01 = src0->nb[1] / ts0; + auto s02 = src0->nb[2] / ts0; + auto s03 = src0->nb[3] / ts0; + fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream); + } } -- cgit v1.2.3