From 3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa Mon Sep 17 00:00:00 2001
From: firecoperana
Date: Thu, 19 Jun 2025 02:24:53 -0500
Subject: add dry sampler (#513)

* add dry sampler

* use vocab instead of model in dry_init function

* fix compile error for build test

---------

Co-authored-by: firecoperana
---
Review notes (not part of the commit message):
 - --dry-allowed-length and --dry-penalty-last-n are integer settings (they
   are dumped with %d below and compared against -1), so they are parsed
   with std::stoi instead of std::stof.
 - --dry-sequence-breaker: the per-character loop used an uninitialized
   index that also shadowed the argv index `i`, and `"" + value[i]` offset a
   string-literal pointer instead of building a one-character string. The
   loop now iterates the argument directly and constructs each
   one-character breaker explicitly.

 common/common.cpp | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

(limited to 'common/common.cpp')

diff --git a/common/common.cpp b/common/common.cpp
index 20e583fc..208d4511 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         sparams.top_n_sigma = std::stof(argv[i]);
         return true;
     }
+
+    if (arg == "--dry-multiplier") {
+        CHECK_ARG
+        sparams.dry_multiplier = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-base") {
+        CHECK_ARG
+        sparams.dry_base = std::stof(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-allowed-length") {
+        CHECK_ARG
+        sparams.dry_allowed_length = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-penalty-last-n") {
+        CHECK_ARG
+        sparams.dry_penalty_last_n = std::stoi(argv[i]);
+        return true;
+    }
+    if (arg == "--dry-sequence-breaker") {
+        CHECK_ARG
+        static bool defaults_cleared = false;
+
+        if (!defaults_cleared) {
+            params.sparams.dry_sequence_breakers.clear();
+            defaults_cleared = true;
+        }
+        const std::string value = argv[i];
+        if (value == "none") {
+            params.sparams.dry_sequence_breakers.clear();
+        }
+        else {
+            // every character of the argument becomes its own breaker string
+            for (const char ch : value) {
+                params.sparams.dry_sequence_breakers.emplace_back(std::string(1, ch));
+            }
+        }
+        return true;
+    }
     if (arg == "--cfg-negative-prompt") {
         CHECK_ARG
         sparams.cfg_negative_prompt = argv[i];
@@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
     }
 
+    if (params.sparams.dry_penalty_last_n == -1) {
+        LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx));
+        params.sparams.dry_penalty_last_n = llama_n_ctx(lctx);
+    }
+
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
@@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
     fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+    fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length);
+    fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base);
+    fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier);
+    fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
     fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
--
cgit v1.2.3