summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
authorfirecoperana <xuqiaowei1124@gmail.com>2025-06-19 02:24:53 -0500
committerGitHub <noreply@github.com>2025-06-19 10:24:53 +0300
commit3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa (patch)
treea3a17ee74e0436253e17f0d322320ed554d34b0a /common/common.cpp
parentc5368148cf3af7a3694e0eb03d24a08326c01d12 (diff)
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 20e583fc..208d4511 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.top_n_sigma = std::stof(argv[i]);
return true;
}
+
+ if (arg == "--dry-multiplier") {
+ CHECK_ARG
+ sparams.dry_multiplier = std::stof(argv[i]);
+ return true;
+ }
+ if (arg == "--dry-base") {
+ CHECK_ARG
+ sparams.dry_base = std::stof(argv[i]);
+ return true;
+ }
+ if (arg == "--dry-allowed-length") {
+ CHECK_ARG
+ sparams.dry_allowed_length = std::stof(argv[i]);
+ return true;
+ }
+ if (arg == "--dry-penalty-last-n") {
+ CHECK_ARG
+ sparams.dry_penalty_last_n = std::stof(argv[i]);
+ return true;
+ }
+ if (arg == "--dry-sequence-breaker") {
+ CHECK_ARG
+ static bool defaults_cleared = false;
+
+ if (!defaults_cleared) {
+ params.sparams.dry_sequence_breakers.clear();
+ defaults_cleared = true;
+ }
+ std::string value= std::string(argv[i]);
+ if (value == "none") {
+ params.sparams.dry_sequence_breakers.clear();
+ }
+ else {
+ for (size_t i; i < value.size(); i++)
+ {
+ params.sparams.dry_sequence_breakers.emplace_back(""+value[i]);
+ }
+ }
+ return true;
+ }
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
@@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
+ if (params.sparams.dry_penalty_last_n == -1) {
+ LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+ params.sparams.dry_penalty_last_n = llama_n_ctx(lctx);
+ }
+
if (params.warmup) {
LOG("warming up the model with an empty run\n");
@@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+ fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+ fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+ fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+ fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);