summaryrefslogtreecommitdiff
path: root/ggml/include
diff options
context:
space:
mode:
authorKawrakow <48489457+ikawrakow@users.noreply.github.com>2024-09-08 10:19:21 +0300
committerGitHub <noreply@github.com>2024-09-08 10:19:21 +0300
commit6136a4b8034f57067e0202d23571c45c98a0bf0b (patch)
tree1d7954eb8cf97f1c26b03fe220b5fb2c9d06ddef /ggml/include
parent0087008d2999eea83f20fd17c775fdc5f8b4b6b5 (diff)
Adding fused rms_norm (#42)
* Fused rms_norm: works on the CPU * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP * Fused rms_norm WIP --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'ggml/include')
-rw-r--r--ggml/include/ggml.h13
1 files changed, 13 insertions, 0 deletions
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 1a4a516c..ab6d172d 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -480,6 +480,7 @@ extern "C" {
GGML_OP_RMS_NORM,
GGML_OP_RMS_NORM_BACK,
GGML_OP_GROUP_NORM,
+ GGML_OP_FUSED_RMS_NORM,
GGML_OP_MUL_MAT,
GGML_OP_MUL_MAT_ID,
@@ -1159,6 +1160,18 @@ extern "C" {
struct ggml_tensor * a,
float eps);
+ GGML_API struct ggml_tensor * ggml_fused_rms_norm(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps);
+
+ GGML_API struct ggml_tensor * ggml_fused_rms_norm_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float eps);
+
// group normalize along ne0*ne1*n_groups
// used in stable-diffusion
GGML_API struct ggml_tensor * ggml_group_norm(