author     Kawrakow <iwankawrakow@gmail.com>    2025-02-07 08:33:42 +0200
committer  GitHub <noreply@github.com>          2025-02-07 08:33:42 +0200
commit     4601a8c3735d8e47c46e0927712d77c4f422be6c (patch)
tree       be32170ffa62b520911efc48ecac60313dc07aef /ggml
parent     b08a2e9dfc0e721f7f190c25f37794390966e326 (diff)
cuda: non-contiguous rms norm (#190)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml')
-rw-r--r--  ggml/src/ggml-cuda/norm.cu  153
1 file changed, 141 insertions, 12 deletions
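
The new *_nc kernels below read src0 through per-dimension element strides instead of requiring a contiguous row-major layout, and always write a contiguous dst. As a rough way to see what the non-contiguous kernel computes (and to check it), here is a plain C++ reference; the strided_view struct and the strides-in-elements convention are illustrative assumptions, not part of this patch.

// CPU reference for the non-contiguous RMS norm (sketch only; not part of the patch).
// x is read through element strides; dst is written contiguously, matching
// rms_norm_f32_nc: dst[col] = x[col] * rsqrtf(mean(x*x) + eps), per row.
#include <cmath>
#include <cstdint>

// Illustrative strided view; strides are in elements, not bytes (assumption).
struct strided_view {
    const float * data;
    int64_t ncols, nrows, nchannels, nsamples;            // ne[0..3]
    int64_t stride_row, stride_channel, stride_sample;    // s01, s02, s03
};

static void rms_norm_nc_ref(const strided_view & x, float * dst, float eps) {
    for (int64_t s = 0; s < x.nsamples; ++s) {
        for (int64_t c = 0; c < x.nchannels; ++c) {
            for (int64_t r = 0; r < x.nrows; ++r) {
                const float * xr = x.data + s*x.stride_sample + c*x.stride_channel + r*x.stride_row;
                float * dr = dst + ((s*x.nchannels + c)*x.nrows + r)*x.ncols;
                double sum = 0.0; // double here for a stable reference; the kernel accumulates in float
                for (int64_t i = 0; i < x.ncols; ++i) sum += (double)xr[i]*xr[i];
                const float scale = 1.0f/std::sqrt((float)(sum/x.ncols) + eps);
                for (int64_t i = 0; i < x.ncols; ++i) dr[i] = scale*xr[i];
            }
        }
    }
}

With stride_row = ncols, stride_channel = ncols*nrows and stride_sample = ncols*nrows*nchannels this reduces to the contiguous case that rms_norm_f32_cuda already handled.
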
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 7e670912..9e4931a3 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -132,6 +132,51 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
template <int block_size>
+static __global__ void rms_norm_f32_nc(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample = blockIdx.z;
+    const int tid = threadIdx.x;
+
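+    // x may be non-contiguous and is addressed via element strides; dst is always written contiguously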
+    x += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
+
+template <int block_size>
static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
@@ -165,6 +210,51 @@ static __global__ void fused_rms_norm_f32(const float * x, const float * y, floa
}
}
+template <int block_size>
+static __global__ void fused_rms_norm_f32_nc(
+        const float * x, const float * y, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows = gridDim.x;
+    const int nchannels = gridDim.y;
+
+    const int row = blockIdx.x;
+    const int channel = blockIdx.y;
+    const int sample = blockIdx.z;
+    const int tid = threadIdx.x;
+
+    x += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float mean = tmp / ncols;
+    const float scale = rsqrtf(mean + eps);
+
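+    // same reduction and scale as rms_norm_f32_nc; the result is additionally multiplied by the per-column weight y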
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * y[col] * x[col];
+    }
+}
+
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    if (ncols < 1024) {
@@ -197,6 +287,19 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
+static void rms_norm_f32_nc_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
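+    // one thread block per row; the y and z grid dimensions index channels and samples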
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
const int ncols, const int nrows, const float eps, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
@@ -209,6 +312,19 @@ static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * ds
}
}
+static void fused_rms_norm_f32_nc_cuda(
+        const float * x, const float * y, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        fused_rms_norm_f32_nc<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        fused_rms_norm_f32_nc<1024><<<blocks_num, block_dims, 0, stream>>>(x, y, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
+
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
@@ -255,18 +371,24 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
-    rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    const int64_t ne00 = src0->ne[0];
+    if (ggml_is_contiguous(src0)) {
+        const int64_t nrows = ggml_nrows(src0);
+        rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
+    } else {
+        auto ts0 = ggml_type_size(src0->type);
+        GGML_ASSERT(src0->nb[0] == ts0);
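+        // nb[] holds strides in bytes; dividing by the element size gives strides in elements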
+        auto s01 = src0->nb[1] / ts0;
+        auto s02 = src0->nb[2] / ts0;
+        auto s03 = src0->nb[3] / ts0;
+        rms_norm_f32_nc_cuda(src0_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+    }
}
void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -281,19 +403,26 @@ void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor *
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(src1->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
    GGML_ASSERT(src0->ne[0] == src1->ne[0]);
    GGML_ASSERT(ggml_nrows(src1) == 1);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
-
    float eps;
    memcpy(&eps, dst->op_params, sizeof(float));
-    fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+    const int64_t ne00 = src0->ne[0];
+
+    if (ggml_is_contiguous(src0)) {
+        const int64_t nrows = ggml_nrows(src0);
+        fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+    } else {
+        auto ts0 = ggml_type_size(src0->type);
+        GGML_ASSERT(src0->nb[0] == ts0);
+        auto s01 = src0->nb[1] / ts0;
+        auto s02 = src0->nb[2] / ts0;
+        auto s03 = src0->nb[3] / ts0;
+        fused_rms_norm_f32_nc_cuda(src0_d, src1_d, dst_d, ne00, src0->ne[1], src0->ne[2], src0->ne[3], s01, s02, s03, eps, stream);
+    }
}
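
For the fused op the arithmetic only differs in the per-column weight: dst[col] = scale * y[col] * x[col], where y is a single row (ggml_nrows(src1) == 1 is asserted above) reused for every row, channel and sample of x. A matching CPU sketch, again with illustrative strides-in-elements parameters rather than anything taken from the patch:

// CPU reference for the fused, non-contiguous RMS norm (sketch only; not part of the patch).
// Parameter order loosely mirrors fused_rms_norm_f32_nc_cuda; strides are in elements.
#include <cmath>
#include <cstdint>

static void fused_rms_norm_nc_ref(
        const float * x, const float * y, float * dst,
        int64_t ncols, int64_t nrows, int64_t nchannels, int64_t nsamples,
        int64_t stride_row, int64_t stride_channel, int64_t stride_sample, float eps) {
    for (int64_t s = 0; s < nsamples; ++s)
    for (int64_t c = 0; c < nchannels; ++c)
    for (int64_t r = 0; r < nrows; ++r) {
        const float * xr = x + s*stride_sample + c*stride_channel + r*stride_row;
        float * dr = dst + ((s*nchannels + c)*nrows + r)*ncols;
        double sum = 0.0;
        for (int64_t i = 0; i < ncols; ++i) sum += (double)xr[i]*xr[i];
        const float scale = 1.0f/std::sqrt((float)(sum/ncols) + eps);
        // y has ncols entries and is broadcast across all rows, channels and samples
        for (int64_t i = 0; i < ncols; ++i) dr[i] = scale*y[i]*xr[i];
    }
}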