summary | refs | log | tree | commit | diff
path: root/common
diff options: context | space | mode
authorAlexey Parfenov <zxed@alkatrazstudio.net>2023-12-23 09:31:49 +0000
committerGitHub <noreply@github.com>2023-12-23 11:31:49 +0200
commit6123979952385847d8348e295d77d6e01da8aa84 (patch)
tree2d536d31ef7e1b6f07468ff6b13710da1fe4f732 /common
parentb9ec82d262cb20d7f0a8a1157bfa9aace40e2625 (diff)
server : allow to specify custom prompt for penalty calculation (#3727)
Diffstat (limited to 'common')
-rw-r--r--  common/sampling.cpp | 8
-rw-r--r--  common/sampling.h   | 3
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 5b15204b..8e45909f 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -203,12 +203,14 @@ static llama_token llama_sampling_sample_impl(
}
// apply penalties
- if (!prev.empty()) {
+ const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+ const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+ if (penalty_tokens_used_size) {
const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
llama_sample_repetition_penalties(ctx_main, &cur_p,
- prev.data() + prev.size() - penalty_last_n,
- penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
+ penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+ penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
diff --git a/common/sampling.h b/common/sampling.h
index fdfa9eed..f16ef97e 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -36,6 +36,9 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // how strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
+
+ std::vector<llama_token> penalty_prompt_tokens;
+ bool use_penalty_prompt_tokens = false;
} llama_sampling_params;
// general sampler context