diff options
author | slaren <2141330+slaren@users.noreply.github.com> | 2023-03-19 19:22:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-19 20:22:48 +0200 |
commit | 50fae10d0339f2bd639f69dd679c0201d939a265 (patch) | |
tree | 907e49aae8cf271ed1080a709256ebf14d53405a /main.cpp | |
parent | 084e2f0ec081c929343d44b09df07ae87cd1ed32 (diff) |
Add --ignore-eos parameter (#181)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'main.cpp')
-rw-r--r-- | main.cpp | 10 |
1 files changed, 9 insertions, 1 deletions
@@ -27,6 +27,8 @@ #define ANSI_COLOR_RESET "\x1b[0m" #define ANSI_BOLD "\x1b[1m" +static const int EOS_TOKEN_ID = 2; + // determine number of model parts based on the dimension static const std::map<int, int> LLAMA_N_PARTS = { { 4096, 1 }, @@ -956,6 +958,11 @@ int main(int argc, char ** argv) { { const int64_t t_start_sample_us = ggml_time_us(); + if (params.ignore_eos) { + // set the logit of the eos token to zero to avoid sampling it + logits[logits.size() - n_vocab + EOS_TOKEN_ID] = 0; + } + id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng); last_n_tokens.erase(last_n_tokens.begin()); @@ -1055,7 +1062,8 @@ int main(int argc, char ** argv) { } // end of text token - if (embd.back() == 2) { + + if (embd.back() == EOS_TOKEN_ID) { if (params.interactive) { is_interacting = true; } else { |