diff options
Diffstat (limited to 'ggml/src/ggml-cuda')
| -rw-r--r-- | ggml/src/ggml-cuda/norm.cu | 75 | ||||
| -rw-r--r-- | ggml/src/ggml-cuda/norm.cuh | 2 |
2 files changed, 77 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 133e219f..7e670912 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } } +template <int block_size> +static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) { + const int row = blockIdx.x*blockDim.y + threadIdx.y; + const int tid = threadIdx.x; + + float tmp = 0.0f; // partial sum for thread in warp + + for (int col = tid; col < ncols; col += block_size) { + const float xi = x[row*ncols + col]; + tmp += xi * xi; + } + + // sum up partial sums + tmp = warp_reduce_sum(tmp); + if (block_size > WARP_SIZE) { + __shared__ float s_sum[32]; + int warp_id = threadIdx.x / WARP_SIZE; + int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + s_sum[warp_id] = tmp; + } + __syncthreads(); + tmp = s_sum[lane_id]; + tmp = warp_reduce_sum(tmp); + } + + const float mean = tmp / ncols; + const float scale = rsqrtf(mean + eps); + + for (int col = tid; col < ncols; col += block_size) { + dst[row*ncols + col] = scale * y[col] * x[row*ncols + col]; + } +} + static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { @@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con } } +static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst, + const int ncols, const int nrows, const float eps, cudaStream_t stream) { + GGML_ASSERT(ncols % WARP_SIZE == 0); + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps); + } else { + const dim3 block_dims(1024, 1, 1); + fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps); + } +} + void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; @@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream); } + +void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + if (!dst->src[1]) { + ggml_cuda_op_rms_norm(ctx, dst); + return; + } + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src0_d = (const float *)src0->data; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream); +} diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 431a8f74..e4f9ee82 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); |
