author     vvhg1 <94630311+vvhg1@users.noreply.github.com>    2023-10-10 09:31:21 +0200
committer  GitHub <noreply@github.com>                        2023-10-10 10:31:21 +0300
commit     11ea5c7d96f2c28e1c99659e08ec0a44574056e2 (patch)
tree       2ab7b5099157439bce85929ce544ec7a971e3b31
parent     95bd60a0a69f57e9a2ff1269667ea484a1a9bb40 (diff)
infill : fix tokenization (#3508)
* infill tokens correction
* server infill tokens correction
* removing any leading whitespace from infill suffix and removing leading space token from suffix when params.escape
* removing any leading whitespace from infill suffix and removing leading space token from suffix when params.escape
* only rm when params.escape, rm space if possible which is added back or rm added space token
* only rm when params.escape, rm space if possible which is added back or rm added space token
* Revert "only rm when params.escape, rm space if possible which is added back or rm added space token"
This reverts commit 63ba0b621f21077c0e3bc6ba6a327534123cb738.
* fix interactive prompt escaping and fix server infill leading space handling (see the sketch after this list)
* rm unnecessary bool check
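
Background for the bullets above: SentencePiece (SPM) tokenizers such as LLaMA's prepend a space token when encoding a string, so a suffix that already begins with a space can pick up a duplicate space once params.escape processing re-adds one. The stand-alone sketch below isolates the handling this commit introduces; it is illustrative only, assuming token id 29871 is the SPM space token (as in the LLaMA vocabulary) and using a callback in place of ::llama_tokenize.

    #include <functional>
    #include <string>
    #include <vector>

    using llama_token = int;

    // Illustrative stand-alone version of the suffix handling in this commit.
    // `tok` stands in for ::llama_tokenize(ctx, text, /*add_bos=*/false).
    std::vector<llama_token> prepare_suffix(
            std::string suffix, bool escape,
            const std::function<std::vector<llama_token>(const std::string &)> & tok) {
        bool suff_rm_leading_spc = escape;
        // If the suffix text itself begins with a space, drop the character now
        // instead of dropping a token after tokenization.
        if (suff_rm_leading_spc && suffix.find_first_of(" ") == 0 && suffix.size() > 1) {
            suffix.erase(0, 1);
            suff_rm_leading_spc = false;
        }
        std::vector<llama_token> toks = tok(suffix);
        const llama_token space_token = 29871; // SPM space token in the LLaMA vocab
        // Otherwise remove the space token the SPM tokenizer prepended
        // (the empty-check guard is added here for safety in isolation).
        if (suff_rm_leading_spc && !toks.empty() && toks[0] == space_token) {
            toks.erase(toks.begin());
        }
        return toks;
    }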
-rw-r--r--   examples/infill/infill.cpp   37
-rw-r--r--   examples/server/server.cpp   15
2 files changed, 46 insertions, 6 deletions
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 9ec75ce4..d994de5e 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -233,10 +233,22 @@ int main(int argc, char ** argv) {
     const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM;
     LOG("add_bos: %d\n", add_bos);
 
+    bool suff_rm_leading_spc = params.escape;
+    if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+        params.input_suffix.erase(0, 1);
+        suff_rm_leading_spc = false;
+    }
     std::vector<llama_token> embd_inp;
-    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
-    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+    std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
+    std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
+    const int space_token = 29871;
+    if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
+        inp_sfx.erase(inp_sfx.begin());
+    }
     inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+    if (add_bos) {
+        inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+    }
     inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
     embd_inp = inp_pfx;
     embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
@@ -627,10 +639,27 @@ int main(int argc, char ** argv) {
             buffer.clear();
             // done taking input, reset color
             console::set_display(console::reset);
+
+            if (params.escape) {
+                //process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here
+                process_escapes(params.input_prefix);
+                process_escapes(params.input_suffix);
+            }
+            suff_rm_leading_spc = params.escape;
+            if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                params.input_suffix.erase(0, 1);
+                suff_rm_leading_spc = false;
+            }
             // tokenize new prefix and suffix
-            std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos);
-            std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos);
+            std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
+            std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
+            if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
+                inp_sfx.erase(inp_sfx.begin());
+            }
             inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx));
+            if (add_bos) {
+                inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx));
+            }
             inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx));
             embd_inp = inp_pfx;
             embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c53a6486..8c5318c6 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -344,9 +344,20 @@ struct llama_server_context
 
     void loadInfill()
     {
-        auto prefix_tokens = tokenize(params.input_prefix, true);  // always add BOS
-        auto suffix_tokens = tokenize(params.input_suffix, true);  // always add BOS
+        bool suff_rm_leading_spc = true;
+        if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+            params.input_suffix.erase(0, 1);
+            suff_rm_leading_spc = false;
+        }
+
+        auto prefix_tokens = tokenize(params.input_prefix, false);
+        auto suffix_tokens = tokenize(params.input_suffix, false);
+        const int space_token = 29871;
+        if (suff_rm_leading_spc && suffix_tokens[0] == space_token) {
+            suffix_tokens.erase(suffix_tokens.begin());
+        }
         prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+        prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS
         prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
         prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
         prefix_tokens.push_back(llama_token_middle(ctx));
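
Net effect, following the loadInfill() hunk above: the infill prompt is assembled as [BOS] <PRE> prefix <SUF> suffix <MID>, with BOS now inserted explicitly in front of the prefix marker rather than requested from the tokenizer. A minimal sketch of that ordering; the special-token ids are plain parameters here because the real values come from llama_token_bos/prefix/suffix/middle(ctx).

    #include <vector>

    using llama_token = int;

    // Illustrative assembly of the infill prompt, mirroring the order used in
    // loadInfill() above: [BOS] <PRE> prefix <SUF> suffix <MID>.
    std::vector<llama_token> build_infill_prompt(
            std::vector<llama_token> prefix_tokens,
            const std::vector<llama_token> & suffix_tokens,
            llama_token tok_bos, llama_token tok_pre,
            llama_token tok_suf, llama_token tok_mid) {
        prefix_tokens.insert(prefix_tokens.begin(), tok_pre);
        prefix_tokens.insert(prefix_tokens.begin(), tok_bos); // BOS goes first
        prefix_tokens.push_back(tok_suf);
        prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
        prefix_tokens.push_back(tok_mid);
        return prefix_tokens;
    }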