From 6136a4b8034f57067e0202d23571c45c98a0bf0b Mon Sep 17 00:00:00 2001 From: Kawrakow <48489457+ikawrakow@users.noreply.github.com> Date: Sun, 8 Sep 2024 10:19:21 +0300 Subject: Adding fused rms_norm (#42) * Fused rms_norm: works on the CPU * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP --------- Co-authored-by: Iwan Kawrakow --- src/llama.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'src') diff --git a/src/llama.cpp b/src/llama.cpp index 8a85144e..768aafa7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7987,6 +7987,16 @@ static struct ggml_tensor * llm_build_norm( llm_norm_type type, const llm_build_cb & cb, int il, float scale_eps = 1) { + + if (type == LLM_NORM_RMS && mw) { + cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps); + if (mb) { + cb(cur, "fused_norm", il); + cur = ggml_add(ctx, cur, mb); + } + return cur; + } + switch (type) { case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break; case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps); break; -- cgit v1.2.3