Diffstat (limited to 'common/common.cpp')
-rw-r--r--  common/common.cpp  79
1 file changed, 66 insertions, 13 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 7a48e9d1..b182ffaa 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -219,12 +219,52 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.rope_freq_scale = std::stof(argv[i]);
+ } else if (arg == "--rope-scaling") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::string value(argv[i]);
+ /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+ else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+ else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+ else { invalid_param = true; break; }
} else if (arg == "--rope-scale") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+ } else if (arg == "--yarn-orig-ctx") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_orig_ctx = std::stoi(argv[i]);
+ } else if (arg == "--yarn-ext-factor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_ext_factor = std::stof(argv[i]);
+ } else if (arg == "--yarn-attn-factor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_attn_factor = std::stof(argv[i]);
+ } else if (arg == "--yarn-beta-fast") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_beta_fast = std::stof(argv[i]);
+ } else if (arg == "--yarn-beta-slow") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_beta_slow = std::stof(argv[i]);
} else if (arg == "--memory-f32") {
params.memory_f16 = false;
} else if (arg == "--top-p") {
@@ -716,9 +756,16 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --cfg-negative-prompt-file FNAME\n");
printf(" negative prompt file to use for guidance. (default: empty)\n");
printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
- printf(" --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale\n");
+ printf(" --rope-scaling {none,linear,yarn}\n");
+ printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+ printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
- printf(" --rope-freq-scale N RoPE frequency linear scaling factor (default: loaded from model)\n");
+ printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+ printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
+ printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+ printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+ printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+ printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
printf(" --no-penalize-nl do not penalize newline token\n");
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
@@ -826,17 +873,23 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
auto cparams = llama_context_default_params();
- cparams.n_ctx = params.n_ctx;
- cparams.n_batch = params.n_batch;
- cparams.n_threads = params.n_threads;
- cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
- cparams.mul_mat_q = params.mul_mat_q;
- cparams.seed = params.seed;
- cparams.f16_kv = params.memory_f16;
- cparams.logits_all = params.logits_all;
- cparams.embedding = params.embedding;
- cparams.rope_freq_base = params.rope_freq_base;
- cparams.rope_freq_scale = params.rope_freq_scale;
+ cparams.n_ctx = params.n_ctx;
+ cparams.n_batch = params.n_batch;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+ cparams.mul_mat_q = params.mul_mat_q;
+ cparams.seed = params.seed;
+ cparams.f16_kv = params.memory_f16;
+ cparams.logits_all = params.logits_all;
+ cparams.embedding = params.embedding;
+ cparams.rope_scaling_type = params.rope_scaling_type;
+ cparams.rope_freq_base = params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale;
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
+ cparams.yarn_orig_ctx = params.yarn_orig_ctx;
return cparams;
}
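
For context, a minimal sketch (not part of the patch) of where the converted cparams end up: the new rope_scaling_type and yarn_* fields ride along into llama_new_context_with_model(). The model path and parameter values are hypothetical, backend init/teardown and most error handling are omitted, and the real wiring with full error handling lives elsewhere in common.cpp.

#include "common.h"
#include "llama.h"

int main() {
    gpt_params params;
    params.model             = "models/7B/model.gguf";  // hypothetical path
    params.n_ctx             = 16384;                   // requested context size
    params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; // what --rope-scaling yarn sets
    params.yarn_orig_ctx     = 4096;                    // what --yarn-orig-ctx sets

    llama_model_params   mparams = llama_model_params_from_gpt_params(params);
    llama_context_params cparams = llama_context_params_from_gpt_params(params); // carries the yarn_* fields copied above

    llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
    if (!model) {
        return 1;
    }
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) {
        llama_free_model(model);
        return 1;
    }

    // ... tokenize / decode as usual ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}
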