author    Seb C <47074056+Sebby37@users.noreply.github.com>  2023-11-21 00:26:59 +1030
committer GitHub <noreply@github.com>  2023-11-20 14:56:59 +0100
commit    881800d1f083c39431cef288347082be516d1c80 (patch)
tree      3a87305b90d9532a5934e3c95f3ec1755932e2e0 /examples/main/main.cpp
parent    f23c0359a32871947169a044eb1dc4dbffd0f405 (diff)
main : Add ChatML functionality to main example (#4046)
Co-authored-by: Sebastian Cramond <sebby37@users.noreply.github.com>
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp | 36
1 file changed, 31 insertions(+), 5 deletions(-)
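
In chatml mode the example wraps the whole conversation in ChatML turn markers: the system prompt is wrapped once at startup, and each interactive turn is opened with a user prefix and closed with an assistant suffix. A minimal standalone sketch of the string the patch assembles (the system and user text below are illustrative; only the <|im_start|>/<|im_end|> markers and the concatenation order come from the patch):

// Sketch: the ChatML prompt layout this patch produces.
#include <iostream>
#include <string>

int main() {
    const std::string system_prompt = "You are a helpful assistant."; // illustrative
    const std::string user_input    = "Hello!";                       // illustrative

    // startup: params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
    std::string prompt = "<|im_start|>system\n" + system_prompt + "<|im_end|>";

    // per turn: cml_pfx + user text + cml_sfx (see the diff below)
    prompt += "\n<|im_start|>user\n" + user_input;
    prompt += "<|im_end|>\n<|im_start|>assistant\n";

    std::cout << prompt;  // generation continues from the open assistant turn
    return 0;
}
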
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 99d219d6..31ec8cad 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -234,8 +234,11 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+    if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
+        if (params.chatml) {
+            params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>";
+        }
         embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
@@ -313,7 +316,7 @@ int main(int argc, char ** argv) {
     }
 
     // number of tokens to keep when resetting context
-    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) {
+    if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
         params.n_keep = (int)embd_inp.size();
     }
 
@@ -324,11 +327,23 @@ int main(int argc, char ** argv) {
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
 
+    // chatml prefix & suffix
+    const auto cml_pfx = ::llama_tokenize(ctx, "\n<|im_start|>user\n", add_bos, true);
+    const auto cml_sfx = ::llama_tokenize(ctx, "<|im_end|>\n<|im_start|>assistant\n", false, true);
+
+    LOG("cml_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_pfx).c_str());
+    LOG("cml_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, cml_sfx).c_str());
+
     // in instruct mode, we inject a prefix and a suffix to each input by the user
     if (params.instruct) {
         params.interactive_first = true;
         params.antiprompt.push_back("### Instruction:\n\n");
     }
+    // similar for chatml mode
+    else if (params.chatml) {
+        params.interactive_first = true;
+        params.antiprompt.push_back("<|im_start|>user\n");
+    }
 
     // enable interactive mode if interactive start is specified
     if (params.interactive_first) {
@@ -705,7 +720,7 @@ int main(int argc, char ** argv) {
 
                     is_interacting = true;
                     printf("\n");
-                } else if (params.instruct) {
+                } else if (params.instruct || params.chatml) {
                     is_interacting = true;
                 }
             }
@@ -713,7 +728,7 @@ int main(int argc, char ** argv) {
             if (n_past > 0 && is_interacting) {
                 LOG("waiting for user input\n");
 
-                if (params.instruct) {
+                if (params.instruct || params.chatml) {
                     printf("\n> ");
                 }
 
@@ -760,6 +775,12 @@ int main(int argc, char ** argv) {
                         n_consumed = embd_inp.size();
                         embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
                     }
+                    // chatml mode: insert user chat prefix
+                    if (params.chatml && !is_antiprompt) {
+                        LOG("inserting chatml prefix\n");
+                        n_consumed = embd_inp.size();
+                        embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
+                    }
                     if (params.escape) {
                         process_escapes(buffer);
                     }
@@ -778,6 +799,11 @@ int main(int argc, char ** argv) {
                         LOG("inserting instruction suffix\n");
                         embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                     }
+                    // chatml mode: insert assistant chat suffix
+                    if (params.chatml) {
+                        LOG("inserting chatml suffix\n");
+                        embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
+                    }
 
                     for (size_t i = original_size; i < embd_inp.size(); ++i) {
                         const llama_token token = embd_inp[i];
@@ -803,7 +829,7 @@ int main(int argc, char ** argv) {
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;
         }
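
Taken together, the per-turn flow reduces to splicing the pre-tokenized ChatML prefix and suffix around each user message. A minimal sketch under those assumptions (plain ints stand in for llama_token; buffer names mirror the patch):

// Sketch of the per-turn token splice in chatml mode.
#include <vector>

static std::vector<int> embd_inp;  // pending input tokens, as in main.cpp
static std::vector<int> cml_pfx;   // tokens of "\n<|im_start|>user\n"
static std::vector<int> cml_sfx;   // tokens of "<|im_end|>\n<|im_start|>assistant\n"

// Called once per interactive turn with the user's tokenized text.
static void queue_user_turn(const std::vector<int> & line_inp) {
    // open the user turn (the patch skips this when the model already
    // produced the "<|im_start|>user\n" antiprompt itself)
    embd_inp.insert(embd_inp.end(), cml_pfx.begin(), cml_pfx.end());
    // the user's text
    embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
    // close the user turn and open the assistant turn
    embd_inp.insert(embd_inp.end(), cml_sfx.begin(), cml_sfx.end());
}

int main() {
    cml_pfx = {1, 2};            // illustrative token ids
    cml_sfx = {3, 4};
    queue_user_turn({5, 6, 7});  // embd_inp now holds pfx + text + sfx
    return 0;
}

Registering "<|im_start|>user\n" as an antiprompt is what hands control back to the user: as soon as the model begins a new user turn on its own, generation stops and the example returns to interactive input. Note that the diffstat above is limited to examples/main/main.cpp; the params.chatml flag itself is wired up elsewhere in the commit, presumably as a --chatml option in the common argument parser.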