diff options
Diffstat (limited to 'ggml/src/ggml.c')
-rw-r--r-- | ggml/src/ggml.c | 138 |
1 files changed, 136 insertions, 2 deletions
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 45fddca5..d562002e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3144,6 +3144,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "RMS_NORM", "RMS_NORM_BACK", "GROUP_NORM", + "FUSED_RMS_NORM", "MUL_MAT", "MUL_MAT_ID", @@ -3207,7 +3208,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3234,6 +3235,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rms_norm(x)", "rms_norm_back(x)", "group_norm(x)", + "fused_rms_norm(x)", "X*Y", "X[i]*Y", @@ -3297,7 +3299,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5737,6 +5739,57 @@ struct ggml_tensor * ggml_rms_norm_inplace( return ggml_rms_norm_impl(ctx, a, eps, true); } +static struct ggml_tensor * ggml_fused_rms_norm_impl( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps, + bool inplace) { + + if (!b) { + return ggml_rms_norm_impl(ctx, a, eps, inplace); + } + + if (ggml_nrows(b) > 1 || a->ne[0] != b->ne[0]) { + struct ggml_tensor * result = ggml_rms_norm_impl(ctx, a, eps, inplace); + result = ggml_mul_impl(ctx, result, b, inplace); + return result; + } + + bool is_node = false; + + if (!inplace && (a->grad)) { + is_node = true; + } + + struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); + + ggml_set_op_params(result, &eps, sizeof(eps)); + + result->op = GGML_OP_FUSED_RMS_NORM; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + result->src[1] = b; + + return result; +} + +struct ggml_tensor * ggml_fused_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps) { + return ggml_fused_rms_norm_impl(ctx, a, b, eps, false); +} + +struct ggml_tensor * ggml_fused_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps) { + return ggml_fused_rms_norm_impl(ctx, a, b, eps, true); +} + // ggml_rms_norm_back struct ggml_tensor * ggml_rms_norm_back( @@ -12455,6 +12508,78 @@ static void ggml_compute_forward_rms_norm( } } +static void ggml_compute_forward_fused_rms_norm_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + if (!src1) { + ggml_compute_forward_rms_norm_f32(params, dst); + return; + } + + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + GGML_ASSERT(src0->nb[0] == sizeof(float)); + GGML_ASSERT(src1->nb[0] == sizeof(float)); + GGML_ASSERT(src1->ne[0] == src0->ne[0]); + GGML_ASSERT(ggml_nrows(src1) == 1); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + GGML_ASSERT(eps > 0.0f); + + // TODO: optimize + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { + const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + + ggml_float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += (ggml_float)(x[i00] * x[i00]); + } + + const float mean = sum/ne00; + + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + const float scale = 1.0f/sqrtf(mean + eps); + + ggml_vec_mul_f32(ne00, y, x, (const float *)src1->data); + ggml_vec_scale_f32(ne00, y, scale); + + } + } + } +} + +static void ggml_compute_forward_fused_rms_norm( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fused_rms_norm_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + static void ggml_compute_forward_rms_norm_back_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { @@ -17708,6 +17833,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_rms_norm(params, tensor); } break; + case GGML_OP_FUSED_RMS_NORM: + { + ggml_compute_forward_fused_rms_norm(params, tensor); + } break; case GGML_OP_RMS_NORM_BACK: { ggml_compute_forward_rms_norm_back(params, tensor); @@ -18398,6 +18527,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_FUSED_RMS_NORM: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_RMS_NORM_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -19465,6 +19598,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_DIV: case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_FUSED_RMS_NORM: case GGML_OP_RMS_NORM_BACK: case GGML_OP_GROUP_NORM: case GGML_OP_CONCAT: |