summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>  2024-06-18 14:19:45 +0200
committer GitHub <noreply@github.com>  2024-06-18 22:19:45 +1000
commit  91c188d6c296bd3384f2a02a83b71187aa3d18b3 (patch)
tree    93e5f4aa20756a15ac8d86394c811736b8b05fb1
parent  84f6de17f6a8602e7ff7f7c7bda36a73f510a2dd (diff)
Only use FIM middle token if it exists (#7648)
* Only use FIM middle if it exists * Only use FIM middle if it exists
-rw-r--r--  examples/infill/infill.cpp  13
-rw-r--r--  examples/server/server.cpp   7
2 files changed, 17 insertions, 3 deletions
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 0e4ec79c..3e82e4a8 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -223,7 +223,11 @@ int main(int argc, char ** argv) {
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
- embd_inp.push_back(llama_token_middle(model));
+
+ const llama_token middle_token = llama_token_middle(model);
+ if (middle_token >= 0) {
+ embd_inp.push_back(middle_token);
+ }
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@@ -528,7 +532,12 @@ int main(int argc, char ** argv) {
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
- embd_inp.push_back(llama_token_middle(model));
+
+ const llama_token middle_token = llama_token_middle(model);
+ if (middle_token >= 0) {
+ embd_inp.push_back(middle_token);
+ }
+
embd.clear();
n_remain = params.n_predict;
n_past = 0;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 919078f2..ec59307b 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2038,7 +2038,12 @@ struct server_context {
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
- prefix_tokens.push_back(llama_token_middle(model));
+
+ const llama_token middle_token = llama_token_middle(model);
+ if (middle_token >= 0) {
+ prefix_tokens.push_back(middle_token);
+ }
+
prompt_tokens = prefix_tokens;
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt