diff options
author | klosax <131523366+klosax@users.noreply.github.com> | 2023-08-07 19:07:19 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-07 19:07:19 +0200 |
commit | f3c3b4b1672d860800639c87d3b5d17564692469 (patch) | |
tree | ebdc8a40a9e374eb4713da9c6233c8c499cd768a /examples/common.cpp | |
parent | 93356bdb7a324a8f6570f99d02af392cd4c45796 (diff) |
Add --rope-scale parameter (#2544)
* common.cpp : Add --rope-scale parameter
* README.md : Add info about using linear rope scaling
Diffstat (limited to 'examples/common.cpp')
-rw-r--r-- | examples/common.cpp | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/examples/common.cpp b/examples/common.cpp index 21f4a035..4d3ba9bb 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.rope_freq_scale = std::stof(argv[i]); + } else if (arg == "--rope-scale") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.rope_freq_scale = 1.0f/std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { @@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); - fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); - fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); + fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); + fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base); + fprintf(stdout, " --rope-freq-scale N RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stdout, " --no-penalize-nl do not penalize newline token\n"); fprintf(stdout, " --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); |