Diffstat (limited to 'llama.cpp')
-rw-r--r--  llama.cpp  220
1 file changed, 167 insertions(+), 53 deletions(-)
diff --git a/llama.cpp b/llama.cpp
index 1c6d482f..685882c2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -54,6 +54,7 @@
#include <cassert>
#include <cinttypes>
#include <climits>
+#include <cmath>
#include <cstdarg>
#include <cstddef>
#include <cstdint>
@@ -235,6 +236,10 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_SCALE_LINEAR,
+ LLM_KV_ROPE_SCALING_TYPE,
+ LLM_KV_ROPE_SCALING_FACTOR,
+ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
+ LLM_KV_ROPE_SCALING_FINETUNED,
LLM_KV_TOKENIZER_MODEL,
LLM_KV_TOKENIZER_LIST,
@@ -276,9 +281,13 @@ static std::map<llm_kv, std::string> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
- { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
- { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
- { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
+ { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
+ { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
+ { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
+ { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
+ { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
+ { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
+ { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
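Note (not part of the diff): with these key names (the "%s" prefix is the architecture name, e.g. "llama"), a YaRN-converted model's GGUF metadata would carry entries along these lines; the values below are purely illustrative:

    llama.rope.scaling.type                    = "yarn"
    llama.rope.scaling.factor                  = 32.0
    llama.rope.scaling.original_context_length = 4096
    llama.rope.scaling.finetuned               = true
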
@@ -552,6 +561,22 @@ do { \
} \
} while (0)
+static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
+ { LLAMA_ROPE_SCALING_NONE, "none" },
+ { LLAMA_ROPE_SCALING_LINEAR, "linear" },
+ { LLAMA_ROPE_SCALING_YARN, "yarn" },
+};
+
+static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
+ for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
+ if (kv.second == name) {
+ return kv.first;
+ }
+ }
+
+ return LLAMA_ROPE_SCALING_UNSPECIFIED;
+}
+
//
// ggml helpers
//
@@ -1035,8 +1060,11 @@ struct llama_hparams {
float f_norm_eps;
float f_norm_rms_eps;
- float rope_freq_base_train;
- float rope_freq_scale_train;
+ float rope_freq_base_train;
+ float rope_freq_scale_train;
+ uint32_t n_yarn_orig_ctx;
+ int8_t rope_scaling_type_train : 3;
+ bool rope_finetuned : 1;
float f_clamp_kqv;
float f_max_alibi_bias;
@@ -1051,6 +1079,8 @@ struct llama_hparams {
if (this->n_layer != other.n_layer) return true;
if (this->n_rot != other.n_rot) return true;
if (this->n_ff != other.n_ff) return true;
+ if (this->rope_finetuned != other.rope_finetuned) return true;
+ if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
const float EPSILON = 1e-9;
@@ -1081,8 +1111,16 @@ struct llama_cparams {
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
- float rope_freq_base;
- float rope_freq_scale;
+ float rope_freq_base;
+ float rope_freq_scale;
+
+ uint32_t n_yarn_orig_ctx;
+ // These hyperparameters are not exposed in GGUF, because all
+ // existing YaRN models use the same values for them.
+ float yarn_ext_factor;
+ float yarn_attn_factor;
+ float yarn_beta_fast;
+ float yarn_beta_slow;
bool mul_mat_q;
};
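Note (not part of the diff): a minimal sketch of what these knobs feed into, following the YaRN paper's correction-dimension formula; the helper name and exact placement on the ggml side are assumptions here, not quotes from the implementation:

    #include <math.h>

    // beta_fast/beta_slow are given in rotations: they pick the head dimensions at which the
    // RoPE frequencies blend from "interpolated" (multiplied by freq_scale) to "extrapolated"
    // (left unscaled). ext_factor scales how much of that blend is applied (0.0f disables YaRN),
    // and attn_factor rescales the magnitude of the rotated vectors.
    static float rope_yarn_corr_dim(int n_dims, int n_orig_ctx, float n_rot, float freq_base) {
        // dimension index at which the embedding completes n_rot rotations over n_orig_ctx positions
        return n_dims * logf(n_orig_ctx / (n_rot * 2 * (float) M_PI)) / (2 * logf(freq_base));
    }
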
@@ -2014,14 +2052,30 @@ static void llm_load_hparams(
hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+ hparams.rope_finetuned = false;
+ GGUF_GET_KEY(ctx, hparams.rope_finetuned, gguf_get_val_bool, GGUF_TYPE_BOOL, false,
+ kv(LLM_KV_ROPE_SCALING_FINETUNED));
+
+ hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
+ GGUF_GET_KEY(ctx, hparams.n_yarn_orig_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false,
+ kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN));
+
// rope_freq_base (optional)
hparams.rope_freq_base_train = 10000.0f;
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+ std::string rope_scaling("linear");
+ GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
+ hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
+ GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
+
// rope_freq_scale (inverse of the kv) is optional
- float ropescale = 1.0f;
- GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
- hparams.rope_freq_scale_train = 1.0f/ropescale;
+ float ropescale = 0.0f;
+ GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
+ if (ropescale == 0.0f) { // try the old key name
+ GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+ }
+ hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
// sanity check for n_rot (optional)
{
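Note (not part of the diff): GGUF stores the scaling *factor*, and the loader keeps its reciprocal. For example, a model converted with a linear/YaRN factor of 32.0 yields rope_freq_scale_train = 1/32 = 0.03125, while a missing or zero factor (under either the new or the legacy key) falls through to 1.0, i.e. no scaling.
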
@@ -2371,6 +2425,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
const auto & hparams = model.hparams;
const auto & vocab = model.vocab;
+ const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+
// hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml.fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
@@ -2389,8 +2445,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
+ LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str());
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
+ LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
+ LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type));
LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype).c_str());
LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, ml.n_elements*1e-9);
@@ -3047,21 +3106,11 @@ static void llm_load_tensors(
model.t_load_us = ggml_time_us() - model.t_start_us;
}
-static bool llama_model_load(
- const std::string & fname,
- llama_model & model,
- int n_gpu_layers,
- int main_gpu,
- const float * tensor_split,
- bool use_mmap,
- bool use_mlock,
- bool vocab_only,
- llama_progress_callback progress_callback,
- void *progress_callback_user_data) {
+static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
try {
- llama_model_loader ml(fname, use_mmap);
+ llama_model_loader ml(fname, params.use_mmap);
- model.hparams.vocab_only = vocab_only;
+ model.hparams.vocab_only = params.vocab_only;
llm_load_arch (ml, model);
llm_load_hparams(ml, model);
@@ -3073,15 +3122,15 @@ static bool llama_model_load(
throw std::runtime_error("vocab size mismatch");
}
- if (vocab_only) {
+ if (params.vocab_only) {
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
return true;
}
llm_load_tensors(
- ml, model, n_gpu_layers,
- main_gpu, tensor_split,
- use_mlock, progress_callback, progress_callback_user_data);
+ ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
+ params.progress_callback, params.progress_callback_user_data
+ );
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
return false;
@@ -3150,6 +3199,7 @@ static struct ggml_tensor * llm_build_inp_embd(
static void llm_build_k_shift(
struct ggml_context * ctx,
const llama_hparams & hparams,
+ const llama_cparams & cparams,
const llama_kv_cache & kv,
struct ggml_cgraph * graph,
llm_rope_type type,
@@ -3162,6 +3212,11 @@ static void llm_build_k_shift(
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_gqa = hparams.n_embd_gqa();
const int64_t n_embd_head = hparams.n_embd_head();
+ const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
+ const float ext_factor = cparams.yarn_ext_factor;
+ const float attn_factor = cparams.yarn_attn_factor;
+ const float beta_fast = cparams.yarn_beta_fast;
+ const float beta_slow = cparams.yarn_beta_slow;
GGML_ASSERT(n_embd_head % n_rot == 0);
@@ -3185,7 +3240,8 @@ static void llm_build_k_shift(
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
- K_shift, n_rot, rope_type, 0, freq_base, freq_scale);
+ K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(graph, tmp);
}
@@ -3442,12 +3498,17 @@ struct llm_build_context {
const float freq_base;
const float freq_scale;
+ const float ext_factor;
+ const float attn_factor;
+ const float beta_fast;
+ const float beta_slow;
const float norm_eps;
const float norm_rms_eps;
const int32_t n_tokens;
const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
const int32_t kv_head; // index of where we store new KV data in the cache
+ const int32_t n_orig_ctx;
const bool do_rope_shift;
@@ -3477,11 +3538,16 @@ struct llm_build_context {
n_embd_gqa (hparams.n_embd_gqa()),
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
+ ext_factor (cparams.yarn_ext_factor),
+ attn_factor (cparams.yarn_attn_factor),
+ beta_fast (cparams.yarn_beta_fast),
+ beta_slow (cparams.yarn_beta_slow),
norm_eps (hparams.f_norm_eps),
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? n_ctx : kv_self.n),
kv_head (worst_case ? n_ctx - n_tokens : kv_self.head),
+ n_orig_ctx (cparams.n_yarn_orig_ctx),
do_rope_shift (worst_case || kv_self.has_shift),
cb (cb),
buf_compute (lctx.buf_compute) {
@@ -3532,7 +3598,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
- llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@@ -3556,10 +3622,18 @@ struct llm_build_context {
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
- Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(Qcur, "Qcur", il);
- Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
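Note (not part of the diff): the call sites above imply the following extended ggml_rope_custom signature; treat this as a reconstruction from the arguments, the authoritative declaration lives in ggml.h:

    struct ggml_tensor * ggml_rope_custom(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,           // tensor to rotate (Q or K, reshaped per head)
            struct ggml_tensor  * b,           // token positions
            int                   n_dims,
            int                   mode,
            int                   n_ctx,
            int                   n_orig_ctx,  // new: original training context for YaRN
            float                 freq_base,
            float                 freq_scale,
            float                 ext_factor,  // new: YaRN extrapolation mix factor
            float                 attn_factor, // new: YaRN attention magnitude scale
            float                 beta_fast,   // new: YaRN "fast" correction boundary (in rotations)
            float                 beta_slow);  // new: YaRN "slow" correction boundary (in rotations)
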
@@ -3634,7 +3708,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
- llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@@ -3658,8 +3732,16 @@ struct llm_build_context {
switch (model.type) {
case MODEL_7B:
- Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
- Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+ Qcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+ n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ Kcur = ggml_rope_custom(
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+ n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
break;
case MODEL_13B:
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, n_tokens);
@@ -3746,7 +3828,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
- llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@@ -3786,10 +3868,16 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
// using mode = 2 for neox mode
- Qcur = ggml_rope_custom(ctx0, Qcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+ Qcur = ggml_rope_custom(
+ ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(Qcur, "Qcur", il);
- Kcur = ggml_rope_custom(ctx0, Kcur, inp_pos, n_embd_head, 2, 0, freq_base, freq_scale);
+ Kcur = ggml_rope_custom(
+ ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
@@ -3960,7 +4048,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1);
if (do_rope_shift) {
- llm_build_k_shift(ctx0, hparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@@ -4053,13 +4141,15 @@ struct llm_build_context {
cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom(
- ctx0, qrot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
- );
+ ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(qrotated, "qrotated", il);
struct ggml_tensor * krotated = ggml_rope_custom(
- ctx0, krot, inp_pos, n_rot, 2, 0, freq_base, freq_scale
- );
+ ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
cb(krotated, "krotated", il);
// ggml currently only supports concatenation on dim=2
@@ -7883,8 +7973,13 @@ struct llama_context_params llama_context_default_params() {
/*.n_batch =*/ 512,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
+ /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
+ /*.yarn_ext_factor =*/ NAN,
+ /*.yarn_attn_factor =*/ 1.0f,
+ /*.yarn_beta_fast =*/ 32.0f,
+ /*.yarn_beta_slow =*/ 1.0f,
/*.mul_mat_q =*/ true,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
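Note (not part of the diff): a hypothetical caller using the new fields, assuming model is an already-loaded llama_model *; the field and enum names are the ones introduced in this diff, the values are only an example:

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx             = 16384;                    // run at 4x the training context
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN;  // override the type stored in the GGUF
    cparams.rope_freq_scale   = 0.25f;                    // linear component of the 4x extension
    cparams.yarn_orig_ctx     = 4096;                     // original training context length
    // yarn_ext_factor stays NAN, so llama_new_context_with_model resolves it to 1.0f for YaRN;
    // yarn_attn_factor, yarn_beta_fast and yarn_beta_slow keep their defaults (1.0f, 32.0f, 1.0f).
    llama_context * lctx = llama_new_context_with_model(model, cparams);
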
@@ -7971,10 +8066,7 @@ struct llama_model * llama_load_model_from_file(
};
}
- if (!llama_model_load(path_model, *model, params.n_gpu_layers,
- params.main_gpu, params.tensor_split,
- params.use_mmap, params.use_mlock, params.vocab_only,
- params.progress_callback, params.progress_callback_user_data)) {
+ if (!llama_model_load(path_model, *model, params)) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
delete model;
return nullptr;
@@ -8000,13 +8092,35 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
- cparams.n_batch = params.n_batch;
- cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
- cparams.rope_freq_base = params.rope_freq_base == 0 ? hparams.rope_freq_base_train : params.rope_freq_base;
- cparams.rope_freq_scale = params.rope_freq_scale == 0 ? hparams.rope_freq_scale_train : params.rope_freq_scale;
- cparams.n_threads = params.n_threads;
- cparams.n_threads_batch = params.n_threads_batch;
- cparams.mul_mat_q = params.mul_mat_q;
+ cparams.n_batch = params.n_batch;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch;
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
+ cparams.mul_mat_q = params.mul_mat_q;
+
+ cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+ cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
+
+ cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
+ hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
+ hparams.n_ctx_train;
+
+ auto rope_scaling_type = params.rope_scaling_type;
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
+ rope_scaling_type = hparams.rope_scaling_type_train;
+ }
+
+ if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
+ cparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
+ }
+
+ if (std::isnan(cparams.yarn_ext_factor)) { // NaN indicates 'not set'
+ cparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
+ }
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
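
Note (not part of the diff): a concrete trace of the resolution above, with hypothetical values. Loading a YaRN-finetuned GGUF that stores scaling.type = "yarn", scaling.factor = 32.0 and original_context_length = 4096, and keeping the default context parameters, ends with rope_freq_scale = 1/32, n_yarn_orig_ctx = 4096 and yarn_ext_factor resolved from NAN to 1.0f. A model whose scaling type resolves to "none" instead gets rope_freq_scale forced to 1.0f and yarn_ext_factor 0.0f, i.e. plain unscaled RoPE.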