summaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp110
1 files changed, 93 insertions, 17 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index b97b7b79..61e960ea 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -37,14 +37,15 @@ static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
static std::vector<llama_token> * g_output_tokens;
-static bool is_interacting = false;
+static bool is_interacting = false;
+static bool need_insert_eot = false;
-static bool file_exists(const std::string &path) {
+static bool file_exists(const std::string & path) {
std::ifstream f(path.c_str());
return f.good();
}
-static bool file_is_empty(const std::string &path) {
+static bool file_is_empty(const std::string & path) {
std::ifstream f;
f.exceptions(std::ifstream::failbit | std::ifstream::badbit);
f.open(path.c_str(), std::ios::in | std::ios::binary | std::ios::ate);
@@ -99,7 +100,8 @@ static void write_logfile(
static void sigint_handler(int signo) {
if (signo == SIGINT) {
if (!is_interacting && g_params->interactive) {
- is_interacting = true;
+ is_interacting = true;
+ need_insert_eot = true;
} else {
console::cleanup();
printf("\n");
@@ -117,6 +119,15 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
LOG_TEE("%s", text);
}
+static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
+ llama_chat_msg new_msg{role, content};
+ auto formatted = llama_chat_format_single(
+ model, g_params->chat_template, chat_msgs, new_msg, role == "user");
+ chat_msgs.push_back({role, content});
+ LOG("formatted: %s\n", formatted.c_str());
+ return formatted;
+}
+
int main(int argc, char ** argv) {
gpt_params params;
g_params = &params;
@@ -190,6 +201,7 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
llama_context * ctx_guidance = NULL;
+ std::vector<llama_chat_msg> chat_msgs;
g_model = &model;
g_ctx = &ctx;
@@ -215,6 +227,15 @@ int main(int argc, char ** argv) {
__func__, n_ctx_train, n_ctx);
}
+ // print chat template example in conversation mode
+ if (params.conversation) {
+ if (params.enable_chat_template) {
+ LOG_TEE("%s: chat template example: %s\n", __func__, llama_chat_format_example(model, params.chat_template).c_str());
+ } else {
+ LOG_TEE("%s: in-suffix/prefix is specified, chat template will be disabled\n", __func__);
+ }
+ }
+
// print system information
{
LOG_TEE("\n");
@@ -244,26 +265,38 @@ int main(int argc, char ** argv) {
}
const bool add_bos = llama_should_add_bos_token(model);
- GGML_ASSERT(llama_add_eos_token(model) != 1);
+ if (!llama_model_has_encoder(model)) {
+ GGML_ASSERT(llama_add_eos_token(model) != 1);
+ }
LOG("add_bos: %d\n", add_bos);
std::vector<llama_token> embd_inp;
- if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
- LOG("tokenize the prompt\n");
- embd_inp = ::llama_tokenize(ctx, params.prompt, true, true);
- } else {
- LOG("use session tokens\n");
- embd_inp = session_tokens;
- }
+ {
+ auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
+ ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+ : params.prompt;
+ if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+ LOG("tokenize the prompt\n");
+ embd_inp = ::llama_tokenize(ctx, prompt, true, true);
+ } else {
+ LOG("use session tokens\n");
+ embd_inp = session_tokens;
+ }
- LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ LOG("prompt: \"%s\"\n", log_tostr(prompt));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ }
// Should not run without any tokens
if (embd_inp.empty()) {
- embd_inp.push_back(llama_token_bos(model));
- LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ if (add_bos) {
+ embd_inp.push_back(llama_token_bos(model));
+ LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+ } else {
+ LOG_TEE("error: input is empty\n");
+ return -1;
+ }
}
// Tokenize negative prompt
@@ -478,6 +511,7 @@ int main(int argc, char ** argv) {
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
std::ostringstream output_ss; g_output_ss = &output_ss;
+ std::ostringstream assistant_ss; // for storing current assistant message, used in conversation mode
// the first thing we will do is to output the prompt, so set color accordingly
console::set_display(console::prompt);
@@ -500,6 +534,24 @@ int main(int argc, char ** argv) {
exit(1);
}
+ if (llama_model_has_encoder(model)) {
+ int enc_input_size = embd_inp.size();
+ llama_token * enc_input_buf = embd_inp.data();
+
+ if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+ LOG_TEE("%s : failed to eval\n", __func__);
+ return 1;
+ }
+
+ llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+ if (decoder_start_token_id == -1) {
+ decoder_start_token_id = llama_token_bos(model);
+ }
+
+ embd_inp.clear();
+ embd_inp.push_back(decoder_start_token_id);
+ }
+
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
if (!embd.empty()) {
@@ -793,11 +845,20 @@ int main(int argc, char ** argv) {
is_antiprompt = true;
}
+ if (params.enable_chat_template) {
+ chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+ }
is_interacting = true;
printf("\n");
}
}
+ // if current token is not EOG, we add it to current assistant message
+ if (params.conversation) {
+ auto id = llama_sampling_last(ctx_sampling);
+ assistant_ss << llama_token_to_piece(ctx, id, false);
+ }
+
if (n_past > 0 && is_interacting) {
LOG("waiting for user input\n");
@@ -848,12 +909,24 @@ int main(int argc, char ** argv) {
string_process_escapes(buffer);
}
+ bool format_chat = params.conversation && params.enable_chat_template;
+ std::string user_inp = format_chat
+ ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
+ : std::move(buffer);
+ // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
- const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
+ const auto line_inp = ::llama_tokenize(ctx, user_inp, false, format_chat);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
+ // if user stop generation mid-way, we must add EOT to finish model's last response
+ if (need_insert_eot && format_chat) {
+ llama_token eot = llama_token_eot(model);
+ embd_inp.push_back(eot == -1 ? llama_token_eos(model) : eot);
+ need_insert_eot = false;
+ }
+
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
@@ -864,6 +937,9 @@ int main(int argc, char ** argv) {
output_ss << llama_token_to_piece(ctx, token);
}
+ // reset assistant message
+ assistant_ss.str("");
+
n_remain -= line_inp.size();
LOG("n_remain: %d\n", n_remain);
} else {