From 6123979952385847d8348e295d77d6e01da8aa84 Mon Sep 17 00:00:00 2001 From: Alexey Parfenov Date: Sat, 23 Dec 2023 09:31:49 +0000 Subject: server : allow to specify custom prompt for penalty calculation (#3727) --- common/sampling.cpp | 8 +++++--- common/sampling.h | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) (limited to 'common') diff --git a/common/sampling.cpp b/common/sampling.cpp index 5b15204b..8e45909f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -203,12 +203,14 @@ static llama_token llama_sampling_sample_impl( } // apply penalties - if (!prev.empty()) { + const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev; + const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); + if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; llama_sample_repetition_penalties(ctx_main, &cur_p, - prev.data() + prev.size() - penalty_last_n, - penalty_last_n, penalty_repeat, penalty_freq, penalty_present); + penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, + penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { diff --git a/common/sampling.h b/common/sampling.h index fdfa9eed..f16ef97e 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -36,6 +36,9 @@ typedef struct llama_sampling_params { float cfg_scale = 1.f; // how strong is guidance std::unordered_map logit_bias; // logit bias for specific tokens + + std::vector penalty_prompt_tokens; + bool use_penalty_prompt_tokens = false; } llama_sampling_params; // general sampler context -- cgit v1.2.3