From 898aeca90a9bb992f506234cf3b8b7f7fa28a1df Mon Sep 17 00:00:00 2001 From: cebtenzzre Date: Wed, 1 Nov 2023 18:04:33 -0400 Subject: llama : implement YaRN RoPE scaling (#2268) Co-authored-by: cebtenzzre Co-authored-by: Jeffrey Quesnelle --- ggml.h | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) (limited to 'ggml.h') diff --git a/ggml.h b/ggml.h index 9d16c5a7..70eb25a6 100644 --- a/ggml.h +++ b/ggml.h @@ -219,7 +219,7 @@ #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 #define GGML_MAX_NAME 64 -#define GGML_MAX_OP_PARAMS 32 +#define GGML_MAX_OP_PARAMS 64 #define GGML_DEFAULT_N_THREADS 4 #if UINTPTR_MAX == 0xFFFFFFFF @@ -1326,8 +1326,13 @@ extern "C" { int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); // in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_custom_inplace( @@ -1337,8 +1342,17 @@ extern "C" { int n_dims, int mode, int n_ctx, + int n_orig_ctx, float freq_base, - float freq_scale); + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // compute correction dims for YaRN RoPE scaling + void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); // xPos RoPE, in-place, returns view(a) GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( -- cgit v1.2.3