Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r--  ggml/src/ggml.c  138
1 file changed, 136 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 45fddca5..d562002e 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3144,6 +3144,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"RMS_NORM",
"RMS_NORM_BACK",
"GROUP_NORM",
+ "FUSED_RMS_NORM",
"MUL_MAT",
"MUL_MAT_ID",
@@ -3207,7 +3208,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -3234,6 +3235,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"rms_norm(x)",
"rms_norm_back(x)",
"group_norm(x)",
+ "fused_rms_norm(x)",
"X*Y",
"X[i]*Y",
@@ -3297,7 +3299,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
+static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -5737,6 +5739,57 @@ struct ggml_tensor * ggml_rms_norm_inplace(
return ggml_rms_norm_impl(ctx, a, eps, true);
}
+static struct ggml_tensor * ggml_fused_rms_norm_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps,
+        bool                  inplace) {
+
+    // no weight tensor: plain rms_norm
+    if (!b) {
+        return ggml_rms_norm_impl(ctx, a, eps, inplace);
+    }
+
+    // the fused kernel only handles a single-row weight of matching width;
+    // otherwise decompose into rms_norm followed by mul
+    if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) {
+        struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace);
+        result = ggml_mul_impl(ctx, result, b, inplace);
+        return result;
+    }
+
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params(result, &eps, sizeof(eps));
+
+    result->op     = GGML_OP_FUSED_RMS_NORM;
+    result->grad   = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_fused_rms_norm(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps) {
+    return ggml_fused_rms_norm_impl(ctx, a, b, eps, false);
+}
+
+struct ggml_tensor * ggml_fused_rms_norm_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 eps) {
+    return ggml_fused_rms_norm_impl(ctx, a, b, eps, true);
+}
+
// ggml_rms_norm_back
struct ggml_tensor * ggml_rms_norm_back(
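/*
 * Usage sketch (not part of the patch): the fused op collapses the usual
 * rms_norm + mul pair into a single graph node. The names `cur`, `norm_w`
 * and the shapes below are illustrative assumptions, not taken from the
 * diff.
 */
static struct ggml_tensor * build_fused_norm(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,     // activations, e.g. [n_embd, n_tokens]
        struct ggml_tensor  * norm_w,  // 1-row weights with ne[0] == cur->ne[0]
        float                 eps) {
    // one node instead of GGML_OP_RMS_NORM followed by GGML_OP_MUL;
    // incompatible shapes fall back to the unfused pair inside the impl
    return ggml_fused_rms_norm(ctx, cur, norm_w, eps);
}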
@@ -12455,6 +12508,78 @@ static void ggml_compute_forward_rms_norm(
}
}
+static void ggml_compute_forward_fused_rms_norm_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    if (!src1) {
+        ggml_compute_forward_rms_norm_f32(params, dst);
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->ne[0] == src0->ne[0]);
+    GGML_ASSERT(ggml_nrows(src1) == 1);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    // y = (x * w) / sqrt(mean(x^2) + eps), one row at a time;
+    // rows are strided across threads via i01
+    // TODO: optimize
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
+                const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
+
+                ggml_float sum = 0.0;
+                for (int64_t i00 = 0; i00 < ne00; i00++) {
+                    sum += (ggml_float)(x[i00] * x[i00]);
+                }
+
+                const float mean = sum/ne00;
+
+                float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
+
+                const float scale = 1.0f/sqrtf(mean + eps);
+
+                ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data);
+                ggml_vec_scale_f32(ne00, y, scale);
+            }
+        }
+    }
+}
+
+static void ggml_compute_forward_fused_rms_norm(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_fused_rms_norm_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
static void ggml_compute_forward_rms_norm_back_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
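/*
 * Reference semantics (sketch, not part of the patch): per row the kernel
 * above computes y[i] = w[i] * x[i] / sqrt(mean(x^2) + eps). A plain
 * scalar equivalent of the ggml_vec_mul_f32 + ggml_vec_scale_f32 pair,
 * assuming <math.h> and <stdint.h>:
 */
static void fused_rms_norm_row_ref(int64_t n, float * y,
        const float * x, const float * w, float eps) {
    double sum = 0.0;
    for (int64_t i = 0; i < n; i++) {
        sum += (double) x[i] * x[i];                // sum of squares
    }
    const float scale = 1.0f/sqrtf((float)(sum/n) + eps);
    for (int64_t i = 0; i < n; i++) {
        y[i] = x[i] * w[i] * scale;                 // mul and scale fused per element
    }
}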
@@ -17708,6 +17833,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_rms_norm(params, tensor);
} break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ ggml_compute_forward_fused_rms_norm(params, tensor);
+ } break;
case GGML_OP_RMS_NORM_BACK:
{
ggml_compute_forward_rms_norm_back(params, tensor);
@@ -18398,6 +18527,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
zero_table);
}
} break;
+ case GGML_OP_FUSED_RMS_NORM:
+ {
+ GGML_ABORT("fatal error"); // TODO: not implemented
+ }
case GGML_OP_RMS_NORM_BACK:
{
GGML_ABORT("fatal error"); // TODO: not implemented
@@ -19465,6 +19598,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DIV:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
+ case GGML_OP_FUSED_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_GROUP_NORM:
case GGML_OP_CONCAT:
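/*
 * Scheduling note (sketch): adding GGML_OP_FUSED_RMS_NORM to this list
 * gives the op n_threads tasks; the forward kernel then strides over rows
 * (i01 = ith; i01 < ne01; i01 += nth), so each thread normalizes a
 * disjoint subset of rows and no synchronization is needed.
 */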