| author | cebtenzzre <cebtenzzre@gmail.com> | 2023-11-01 18:04:33 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-01 18:04:33 -0400 |
| commit | 898aeca90a9bb992f506234cf3b8b7f7fa28a1df (patch) | |
| tree | 125f8a9b466efd4534ecd3e64419ece001c86a7d /common/common.h | |
| parent | c43c2da8afacaddfe51c09b21dbd9922cd0ea46b (diff) | |
llama : implement YaRN RoPE scaling (#2268)
Co-authored-by: cebtenzzre <cebtenzzre@gmail.com>
Co-authored-by: Jeffrey Quesnelle <jquesnelle@gmail.com>
Diffstat (limited to 'common/common.h')
-rw-r--r-- | common/common.h | 7 |
1 file changed, 7 insertions, 0 deletions
```diff
diff --git a/common/common.h b/common/common.h
index 343b2721..7be69f92 100644
--- a/common/common.h
+++ b/common/common.h
@@ -9,6 +9,7 @@
 #define LOG_NO_FILE_LINE_FUNCTION
 #include "log.h"
 
+#include <cmath>
 #include <string>
 #include <vector>
 #include <random>
@@ -54,6 +55,12 @@ struct gpt_params {
     int32_t n_beams         = 0;    // if non-zero then use beam search of given width.
     float   rope_freq_base  = 0.0f; // RoPE base frequency
     float   rope_freq_scale = 0.0f; // RoPE frequency scaling factor
+    float   yarn_ext_factor  = NAN;   // YaRN extrapolation mix factor
+    float   yarn_attn_factor = 1.0f;  // YaRN magnitude scaling factor
+    float   yarn_beta_fast   = 32.0f; // YaRN low correction dim
+    float   yarn_beta_slow   = 1.0f;  // YaRN high correction dim
+    int32_t yarn_orig_ctx    = 0;     // YaRN original context length
+    int8_t  rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters
     struct llama_sampling_params sparams;
```
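For context on what these knobs control: the sketch below is a hypothetical illustration (not code from this commit, and not llama.cpp's actual implementation) of how a YaRN-style "NTK-by-parts" RoPE frequency mix can be derived from `yarn_ext_factor`, `yarn_attn_factor`, `yarn_beta_fast`, `yarn_beta_slow`, and the original context length, following the formulas in the YaRN paper. The helper names `yarn_corr_dim`, `yarn_ramp`, `yarn_theta`, `yarn_mscale`, and the example values in `main()` are assumptions for illustration only.

```cpp
// Minimal, self-contained sketch (assumed helpers, not part of this change) of a
// YaRN-style correction range and frequency mix driven by the new gpt_params fields.
#include <algorithm>
#include <cmath>
#include <cstdio>

// Dimension index at which a RoPE dimension completes n_rot rotations over the
// original training context; used to bound the correction range.
static float yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float freq_base) {
    const float pi = 3.14159265358979f;
    return n_dims * std::log(n_orig_ctx / (n_rot * 2.0f * pi)) / (2.0f * std::log(freq_base));
}

// Linear ramp in [0, 1] between the low (beta_fast) and high (beta_slow) correction dims.
static float yarn_ramp(float low, float high, int i) {
    const float y = (static_cast<float>(i / 2) - low) / std::max(0.001f, high - low);
    return 1.0f - std::clamp(y, 0.0f, 1.0f);
}

// Blend the interpolated and extrapolated rotation angles for dimension pair i:
// high-frequency dims are extrapolated, low-frequency dims are interpolated.
static float yarn_theta(int i, float theta_extrap, float freq_scale,
                        float corr_low, float corr_high, float ext_factor) {
    const float theta_interp = freq_scale * theta_extrap;
    const float mix          = yarn_ramp(corr_low, corr_high, i) * ext_factor;
    return theta_interp * (1.0f - mix) + theta_extrap * mix;
}

// Magnitude ("attention temperature") correction from the YaRN paper,
// scaled by the user-supplied attn_factor.
static float yarn_mscale(float freq_scale, float attn_factor) {
    return attn_factor * (1.0f + 0.1f * std::log(1.0f / freq_scale));
}

int main() {
    // Example values mirroring the new gpt_params fields (assumed, not taken
    // from any particular model's defaults).
    const int   n_dims           = 128;      // rotary dimensions per head
    const int   yarn_orig_ctx    = 4096;     // original training context
    const float freq_base        = 10000.0f; // rope_freq_base
    const float freq_scale       = 0.25f;    // rope_freq_scale for a 4x context extension
    const float yarn_ext_factor  = 1.0f;
    const float yarn_attn_factor = 1.0f;
    const float yarn_beta_fast   = 32.0f;
    const float yarn_beta_slow   = 1.0f;

    const float corr_low  = yarn_corr_dim(n_dims, yarn_orig_ctx, yarn_beta_fast, freq_base);
    const float corr_high = yarn_corr_dim(n_dims, yarn_orig_ctx, yarn_beta_slow, freq_base);

    std::printf("correction range: [%.2f, %.2f], mscale = %.3f\n",
                corr_low, corr_high, yarn_mscale(freq_scale, yarn_attn_factor));
    for (int i = 0; i < n_dims; i += 32) {
        const float theta_extrap = std::pow(freq_base, -static_cast<float>(i) / n_dims);
        std::printf("dim %3d: theta = %g\n", i,
                    yarn_theta(i, theta_extrap, freq_scale, corr_low, corr_high, yarn_ext_factor));
    }
    return 0;
}
```

In this sketch, setting the extrapolation mix to 0 makes the blend collapse to plain linear position interpolation (`rope_freq_scale` alone), which is consistent with `yarn_ext_factor` acting as the extrapolation mix knob.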