summaryrefslogtreecommitdiff
path: root/ggml/src/ggml-cuda
diff options
context:
space:
mode:
Diffstat (limited to 'ggml/src/ggml-cuda')
-rw-r--r--ggml/src/ggml-cuda/norm.cu75
-rw-r--r--ggml/src/ggml-cuda/norm.cuh2
2 files changed, 77 insertions, 0 deletions
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 133e219f..7e670912 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
+template <int block_size>
+static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ const float mean = tmp / ncols;
+ const float scale = rsqrtf(mean + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
+ }
+}
+
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
+static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
+ const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+ }
+}
+
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ if (!dst->src[1]) {
+ ggml_cuda_op_rms_norm(ctx, dst);
+ return;
+ }
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+}
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 431a8f74..e4f9ee82 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);