summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-08 10:19:21 +0300
committerGitHub <noreply@github.com>2024-09-08 10:19:21 +0300
commit6136a4b8034f57067e0202d23571c45c98a0bf0b (patch)
tree1d7954eb8cf97f1c26b03fe220b5fb2c9d06ddef /src
parent0087008d2999eea83f20fd17c775fdc5f8b4b6b5 (diff)
Adding fused rms_norm (#42)
* Fused rms_norm: works on the CPU * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/llama.cpp10
1 files changed, 10 insertions, 0 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 8a85144e..768aafa7 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7987,6 +7987,16 @@ static struct ggml_tensor * llm_build_norm(
llm_norm_type type,
const llm_build_cb & cb,
int il, float scale_eps = 1) {
+
+ if (type == LLM_NORM_RMS && mw) {
+ cur = ggml_fused_rms_norm(ctx, cur, mw, scale_eps * hparams.f_norm_rms_eps);
+ if (mb) {
+ cb(cur, "fused_norm", il);
+ cur = ggml_add(ctx, cur, mb);
+ }
+ return cur;
+ }
+
switch (type) {
case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, scale_eps * hparams.f_norm_rms_eps); break;