From 0c06204fb39aa5560e883e0ae74be9518c57d88e Mon Sep 17 00:00:00 2001 From: Xiao-Yong Jin Date: Tue, 25 Jul 2023 07:19:11 -0500 Subject: main : add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS (#2304) * add `--in-prefix-bos` to prefix BOS to user inputs; keep EOS The BOS precedes the string specified by `--in-prefix`. Model generated EOS is now kept in the context. It provides a way to strictly following the prompt format used in Llama-2-chat. The EOS handling also benefits some existing finetunes that uses EOS to mark the end of turn. * examples/common: move input_prefix_bos to other bools --- examples/main/main.cpp | 47 ++++++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 17 deletions(-) (limited to 'examples/main/main.cpp') diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 16ddc227..3796a923 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -325,6 +325,10 @@ int main(int argc, char ** argv) { } } + if (params.input_prefix_bos) { + fprintf(stderr, "Input prefix with BOS\n"); + } + if (!params.input_prefix.empty()) { fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str()); } @@ -633,16 +637,6 @@ int main(int argc, char ** argv) { last_n_tokens.push_back(id); } - // replace end of text token with newline token when in interactive mode - if (id == llama_token_eos() && params.interactive && !params.instruct) { - id = llama_token_newline.front(); - if (params.antiprompt.size() != 0) { - // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); - embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); - } - } - // add it to the context embd.push_back(id); @@ -708,11 +702,34 @@ int main(int argc, char ** argv) { } } + // deal with end of text token in interactive mode + if (last_n_tokens.back() == llama_token_eos()) { + if (params.interactive) { + if (params.antiprompt.size() != 0) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + is_interacting = true; + printf("\n"); + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + fflush(stdout); + } else if (params.instruct) { + is_interacting = true; + } + } + if (n_past > 0 && is_interacting) { if (params.instruct) { printf("\n> "); } + if (params.input_prefix_bos) { + embd_inp.push_back(llama_token_bos()); + } + std::string buffer; if (!params.input_prefix.empty()) { buffer += params.input_prefix; @@ -776,13 +793,9 @@ int main(int argc, char ** argv) { } // end of text token - if (!embd.empty() && embd.back() == llama_token_eos()) { - if (params.instruct) { - is_interacting = true; - } else { - fprintf(stderr, " [end of text]\n"); - break; - } + if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) { + fprintf(stderr, " [end of text]\n"); + break; } // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. -- cgit v1.2.3