diff options
author | anzz1 <anzz1@live.com> | 2023-03-28 17:09:55 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-28 17:09:55 +0300 |
commit | 7b8dbcb78b2f65c4676e41da215800d65846edd0 (patch) | |
tree | ab17f652bb706aac95699ee323c6aaa36a1f4706 /examples/main/main.cpp | |
parent | 4b8efff0e3945090379aa2f897ff125c8f9cdbae (diff) |
main.cpp fixes, refactoring (#571)
- main: entering empty line passes back control without new input in interactive/instruct modes
- instruct mode: keep prompt fix
- instruct mode: duplicate instruct prompt fix
- refactor: move common console code from main->common
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 164 |
1 files changed, 53 insertions, 111 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 66b7c2d5..d5ab2cf7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -18,58 +18,13 @@ #include <signal.h> #endif -#if defined (_WIN32) -#pragma comment(lib,"kernel32.lib") -extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); -extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); -extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); -extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); -#endif - -#define ANSI_COLOR_RED "\x1b[31m" -#define ANSI_COLOR_GREEN "\x1b[32m" -#define ANSI_COLOR_YELLOW "\x1b[33m" -#define ANSI_COLOR_BLUE "\x1b[34m" -#define ANSI_COLOR_MAGENTA "\x1b[35m" -#define ANSI_COLOR_CYAN "\x1b[36m" -#define ANSI_COLOR_RESET "\x1b[0m" -#define ANSI_BOLD "\x1b[1m" - -/* Keep track of current color of output, and emit ANSI code if it changes. */ -enum console_state { - CONSOLE_STATE_DEFAULT=0, - CONSOLE_STATE_PROMPT, - CONSOLE_STATE_USER_INPUT -}; - -static console_state con_st = CONSOLE_STATE_DEFAULT; -static bool con_use_color = false; - -void set_console_state(console_state new_st) { - if (!con_use_color) return; - // only emit color code if state changed - if (new_st != con_st) { - con_st = new_st; - switch(con_st) { - case CONSOLE_STATE_DEFAULT: - printf(ANSI_COLOR_RESET); - return; - case CONSOLE_STATE_PROMPT: - printf(ANSI_COLOR_YELLOW); - return; - case CONSOLE_STATE_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); - return; - } - } -} +static console_state con_st; static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { @@ -81,32 +36,6 @@ void sigint_handler(int signo) { } #endif -#if defined (_WIN32) -void win32_console_init(void) { - unsigned long dwMode = 0; - void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) - if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { - hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) - if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { - hConOut = 0; - } - } - if (hConOut) { - // Enable ANSI colors on Windows 10+ - if (con_use_color && !(dwMode & 0x4)) { - SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) - } - // Set console output codepage to UTF8 - SetConsoleOutputCP(65001); // CP_UTF8 - } - void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) - if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { - // Set console input codepage to UTF8 - SetConsoleCP(65001); // CP_UTF8 - } -} -#endif - int main(int argc, char ** argv) { gpt_params params; params.model = "models/llama-7B/ggml-model.bin"; @@ -115,13 +44,12 @@ int main(int argc, char ** argv) { return 1; } - // save choice to use color for later // (note for later: this is a slightly awkward choice) - con_use_color = params.use_color; + con_st.use_color = params.use_color; #if defined (_WIN32) - win32_console_init(); + win32_console_init(params.use_color); #endif if (params.perplexity) { @@ -218,7 +146,10 @@ int main(int argc, char ** argv) { return 1; } - params.n_keep = std::min(params.n_keep, (int) embd_inp.size()); + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) { + params.n_keep = (int)embd_inp.size(); + } // prefix & suffix for instruct mode const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true); @@ -226,16 +157,12 @@ int main(int argc, char ** argv) { // in instruct mode, we inject a prefix and a suffix to each input by the user if (params.instruct) { - params.interactive = true; + params.interactive_start = true; params.antiprompt.push_back("### Instruction:\n\n"); } - // enable interactive mode if reverse prompt is specified - if (params.antiprompt.size() != 0) { - params.interactive = true; - } - - if (params.interactive_start) { + // enable interactive mode if reverse prompt or interactive start is specified + if (params.antiprompt.size() != 0 || params.interactive_start) { params.interactive = true; } @@ -297,17 +224,18 @@ int main(int argc, char ** argv) { #endif " - Press Return to return control to LLaMa.\n" " - If you want to submit another line, end your input in '\\'.\n\n"); - is_interacting = params.interactive_start || params.instruct; + is_interacting = params.interactive_start; } - bool input_noecho = false; + bool is_antiprompt = false; + bool input_noecho = false; int n_past = 0; int n_remain = params.n_predict; int n_consumed = 0; // the first thing we will do is to output the prompt, so set color accordingly - set_console_state(CONSOLE_STATE_PROMPT); + set_console_color(con_st, CONSOLE_COLOR_PROMPT); std::vector<llama_token> embd; @@ -408,36 +336,38 @@ int main(int argc, char ** argv) { } // reset color to default if we there is no pending user input if (!input_noecho && (int)embd_inp.size() == n_consumed) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); } // in interactive mode, and not currently processing queued inputs; // check if we should prompt the user for more if (params.interactive && (int) embd_inp.size() <= n_consumed) { + // check for reverse prompt - std::string last_output; - for (auto id : last_n_tokens) { - last_output += llama_token_to_str(ctx, id); - } + if (params.antiprompt.size()) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } - // Check if each of the reverse prompts appears at the end of the output. - for (std::string & antiprompt : params.antiprompt) { - if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { - is_interacting = true; - set_console_state(CONSOLE_STATE_USER_INPUT); - fflush(stdout); - break; + is_antiprompt = false; + // Check if each of the reverse prompts appears at the end of the output. + for (std::string & antiprompt : params.antiprompt) { + if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { + is_interacting = true; + is_antiprompt = true; + set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); + fflush(stdout); + break; + } } } if (n_past > 0 && is_interacting) { // potentially set color to indicate we are taking user input - set_console_state(CONSOLE_STATE_USER_INPUT); + set_console_color(con_st, CONSOLE_COLOR_USER_INPUT); if (params.instruct) { - n_consumed = embd_inp.size(); - embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - printf("\n> "); } @@ -463,16 +393,28 @@ int main(int argc, char ** argv) { } while (another_line); // done taking input, reset color - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); - auto line_inp = ::llama_tokenize(ctx, buffer, false); - embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + // Add tokens to embd only if the input buffer is non-empty + // Entering a empty line lets the user pass control back + if (buffer.length() > 1) { - if (params.instruct) { - embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); - } + // instruct mode: insert instruction prefix + if (params.instruct && !is_antiprompt) { + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } - n_remain -= line_inp.size(); + auto line_inp = ::llama_tokenize(ctx, buffer, false); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + + // instruct mode: insert response suffix + if (params.instruct) { + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + n_remain -= line_inp.size(); + } input_noecho = true; // do not echo this again } @@ -506,7 +448,7 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_color(con_st, CONSOLE_COLOR_DEFAULT); return 0; } |