diff options
author | Alexey Parfenov <zxed@alkatrazstudio.net> | 2023-12-23 09:31:49 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-12-23 11:31:49 +0200 |
commit | 6123979952385847d8348e295d77d6e01da8aa84 (patch) | |
tree | 2d536d31ef7e1b6f07468ff6b13710da1fe4f732 /examples/server/server.cpp | |
parent | b9ec82d262cb20d7f0a8a1157bfa9aace40e2625 (diff) |
server : allow to specify custom prompt for penalty calculation (#3727)
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 44 |
1 files changed, 44 insertions, 0 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 04038530..72dfe452 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -761,6 +761,42 @@ struct llama_server_context slot->prompt = ""; } + slot->sparams.penalty_prompt_tokens.clear(); + slot->sparams.use_penalty_prompt_tokens = false; + const auto &penalty_prompt = data.find("penalty_prompt"); + if (penalty_prompt != data.end()) + { + if (penalty_prompt->is_string()) + { + const auto penalty_prompt_string = penalty_prompt->get<std::string>(); + auto penalty_tokens = llama_tokenize(model, penalty_prompt_string, false); + slot->sparams.penalty_prompt_tokens.swap(penalty_tokens); + if (slot->params.n_predict > 0) + { + slot->sparams.penalty_prompt_tokens.reserve(slot->sparams.penalty_prompt_tokens.size() + slot->params.n_predict); + } + slot->sparams.use_penalty_prompt_tokens = true; + } + else if (penalty_prompt->is_array()) + { + const auto n_tokens = penalty_prompt->size(); + slot->sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot->params.n_predict)); + const int n_vocab = llama_n_vocab(model); + for (const auto &penalty_token : *penalty_prompt) + { + if (penalty_token.is_number_integer()) + { + const auto tok = penalty_token.get<llama_token>(); + if (tok >= 0 && tok < n_vocab) + { + slot->sparams.penalty_prompt_tokens.push_back(tok); + } + } + } + slot->sparams.use_penalty_prompt_tokens = true; + } + } + slot->sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false)) @@ -992,6 +1028,12 @@ struct llama_server_context slot.generated_text += token_str; slot.has_next_token = true; + if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) + { + // we can change penalty_prompt_tokens because it is always created from scratch each request + slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); + } + // check if there is incomplete UTF-8 character at the end bool incomplete = false; for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) @@ -1183,6 +1225,8 @@ struct llama_server_context {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, + {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, + {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, |