summaryrefslogtreecommitdiff
path: root/ggml
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-08 10:19:21 +0300
committerGitHub <noreply@github.com>2024-09-08 10:19:21 +0300
commit6136a4b8034f57067e0202d23571c45c98a0bf0b (patch)
tree1d7954eb8cf97f1c26b03fe220b5fb2c9d06ddef /ggml
parent0087008d2999eea83f20fd17c775fdc5f8b4b6b5 (diff)
Adding fused rms_norm (#42)
* Fused rms_norm: works on the CPU * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml')
-rw-r--r--ggml/include/ggml.h13
-rw-r--r--ggml/src/ggml-cuda.cu4
-rw-r--r--ggml/src/ggml-cuda/norm.cu75
-rw-r--r--ggml/src/ggml-cuda/norm.cuh2
-rw-r--r--ggml/src/ggml-metal.m35
-rw-r--r--ggml/src/ggml-metal.metal51
-rw-r--r--ggml/src/ggml.c138
7 files changed, 316 insertions, 2 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 1a4a516c..ab6d172d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -480,6 +480,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
+ GGML_OP_FUSED_RMS_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -1159,6 +1160,18 @@ extern "C" {
struct ggml_tensor * a,
float eps);
+ GGML_API struct ggml_tensor * ggml_fused_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps);
+
// group normalize along ne0*ne1*n_groups
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_group_norm(
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 056ca4a4..cf053559 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -2248,6 +2248,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_RMS_NORM:
ggml_cuda_op_rms_norm(ctx, dst);
break;
+ case GGML_OP_FUSED_RMS_NORM:
+ ggml_cuda_op_fused_rms_norm(ctx, dst);
+ break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_CUDA_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
@@ -2871,6 +2874,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SOFTCAP:
case GGML_OP_SQR:
diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu
index 133e219f..7e670912 100644
--- a/ggml/src/ggml-cuda/norm.cu
+++ b/ggml/src/ggml-cuda/norm.cu
@@ -131,6 +131,40 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
+template <int block_size>
+static __global__ void fused_rms_norm_f32(const float * x, const float * y, float * dst, const int ncols, const float eps) {
+ const int row = blockIdx.x*blockDim.y + threadIdx.y;
+ const int tid = threadIdx.x;
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int col = tid; col < ncols; col += block_size) {
+ const float xi = x[row*ncols + col];
+ tmp += xi * xi;
+ }
+
+ // sum up partial sums
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ const float mean = tmp / ncols;
+ const float scale = rsqrtf(mean + eps);
+
+ for (int col = tid; col < ncols; col += block_size) {
+ dst[row*ncols + col] = scale * y[col] * x[row*ncols + col];
+ }
+}
+
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
@@ -163,6 +197,18 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
}
}
+static void fused_rms_norm_f32_cuda(const float * x, const float * y, float * dst,
+ const int ncols, const int nrows, const float eps, cudaStream_t stream) {
+ GGML_ASSERT(ncols % WARP_SIZE == 0);
+ if (ncols < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ fused_rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ fused_rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, y, dst, ncols, eps);
+ }
+}
+
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
@@ -222,3 +268,32 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ if (!dst->src[1]) {
+ ggml_cuda_op_rms_norm(ctx, dst);
+ return;
+ }
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[0] == src1->ne[0]);
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t nrows = ggml_nrows(src0);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ fused_rms_norm_f32_cuda(src0_d, src1_d, dst_d, ne00, nrows, eps, stream);
+}
diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh
index 431a8f74..e4f9ee82 100644
--- a/ggml/src/ggml-cuda/norm.cuh
+++ b/ggml/src/ggml-cuda/norm.cuh
@@ -5,3 +5,5 @@ void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 83bd76f9..b3f6e60c 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -104,6 +104,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
GGML_METAL_KERNEL_TYPE_RMS_NORM,
+ GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM,
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
GGML_METAL_KERNEL_TYPE_NORM,
GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
@@ -613,6 +614,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ6_K, get_rows_iq6_k, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM, fused_rms_norm, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
@@ -884,6 +886,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
return true; //ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op);
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_GROUP_NORM:
return ctx->support_simdgroup_reduction;
case GGML_OP_NORM:
@@ -2608,6 +2611,38 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ int nth = 32; // SIMD width
+
+ while (nth < ne00/4 && nth < 1024) {
+ nth *= 2;
+ }
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FUSED_RMS_NORM].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:5];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_GROUP_NORM:
{
GGML_ASSERT(ne00 % 4 == 0);
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f9c88a37..d7af1800 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1038,6 +1038,57 @@ kernel void kernel_rms_norm(
}
}
+kernel void kernel_fused_rms_norm(
+ device const void * src0,
+ device const void * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+
+ float4 sumf = 0;
+ float all_sum = 0;
+
+ // parallel sum
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ sumf += x[i00] * x[i00];
+ }
+ all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
+ all_sum = simd_sum(all_sum);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = all_sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ all_sum = buf[tiisg];
+ all_sum = simd_sum(all_sum);
+ }
+
+ const float mean = all_sum/ne00;
+ const float scale = 1.0f/sqrt(mean + eps);
+
+ device float4 * y = (device float4 *) (dst + tgpig*ne00);
+ device float4 * z = (device float4 *)src1;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ y[i00] = x[i00] * z[i00] * scale;
+ }
+}
+
kernel void kernel_group_norm(
device const float * src0,
device float * dst,
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 45fddca5..d562002e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3144,6 +3144,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
+ "FUSED_RMS_NORM",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3207,7 +3208,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3234,6 +3235,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
+ "fused_rms_norm(x)",
"X*Y",
"X[i]*Y",
@@ -3297,7 +3299,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5737,6 +5739,57 @@ struct ggml_tensor * ggml_rms_norm_inplace(
return ggml_rms_norm_impl(ctx, a, eps, true);
}
+static struct ggml_tensor * ggml_fused_rms_norm_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps,
+ bool inplace) {
+
+ if (!b) {
+ return ggml_rms_norm_impl(ctx, a, eps, inplace);
+ }
+
+ if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
+ struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
+ result = ggml_mul_impl(ctx, result, b, inplace);
+ return result;
+ }
+
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params(result, &eps, sizeof(eps));
+
+ result->op = GGML_OP_FUSED_RMS_NORM;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+ result->src[1] = b;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_fused_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps) {
+ return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
+}
+
+struct ggml_tensor * ggml_fused_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps) {
+ return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
+}
+
// ggml_rms_norm_back
struct ggml_tensor * ggml_rms_norm_back(
@@ -12455,6 +12508,78 @@ static void ggml_compute_forward_rms_norm(
}
}
+static void ggml_compute_forward_fused_rms_norm_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+ const struct ggml_tensor * src1 = dst->src[1];
+
+ if (!src1) {
+ ggml_compute_forward_rms_norm_f32(params, dst);
+ return;
+ }
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+ GGML_ASSERT(ggml_nrows(src1) == 1);
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float eps;
+ memcpy(&eps, dst->op_params, sizeof(float));
+
+ GGML_ASSERT(eps > 0.0f);
+
+ // TODO: optimize
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
+ for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+ const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+ ggml_float sum = 0.0;
+ for (int64_t i00 = 0; i00 < ne00; i00++) {
+ sum += (ggml_float)(x[i00] * x[i00]);
+ }
+
+ const float mean = sum/ne00;
+
+ float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+ const float scale = 1.0f/sqrtf(mean + eps);
+
+ ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
+ ggml_vec_scale_f32(ne00, y, scale);
+
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_fused_rms_norm(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_fused_rms_norm_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
static void ggml_compute_forward_rms_norm_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
@@ -17708,6 +17833,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rms_norm(params, tensor);
} break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ ggml_compute_forward_fused_rms_norm(params, tensor);
+ } break;
case GGML_OP_RMS_NORM_BACK:
{
ggml_compute_forward_rms_norm_back(params, tensor);
@@ -18398,6 +18527,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_OP_RMS_NORM_BACK:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19465,6 +19598,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DIV:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT: