summaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
authoranzz1 <anzz1@live.com>2023-03-28 17:09:55 +0300
committerGitHub <noreply@github.com>2023-03-28 17:09:55 +0300
commit7b8dbcb78b2f65c4676e41da215800d65846edd0 (patch)
treeab17f652bb706aac95699ee323c6aaa36a1f4706 /examples/main/main.cpp
parent4b8efff0e3945090379aa2f897ff125c8f9cdbae (diff)
main.cpp fixes, refactoring (#571)
- main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp164
1 files changed, 53 insertions, 111 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 66b7c2d5..d5ab2cf7 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -18,58 +18,13 @@
#include <signal.h>
#endif
-#if defined (_WIN32)
-#pragma comment(lib,"kernel32.lib")
-extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle);
-extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID);
-extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID);
-#endif
-
-#define ANSI_COLOR_RED "\x1b[31m"
-#define ANSI_COLOR_GREEN "\x1b[32m"
-#define ANSI_COLOR_YELLOW "\x1b[33m"
-#define ANSI_COLOR_BLUE "\x1b[34m"
-#define ANSI_COLOR_MAGENTA "\x1b[35m"
-#define ANSI_COLOR_CYAN "\x1b[36m"
-#define ANSI_COLOR_RESET "\x1b[0m"
-#define ANSI_BOLD "\x1b[1m"
-
-/* Keep track of current color of output, and emit ANSI code if it changes. */
-enum console_state {
- CONSOLE_STATE_DEFAULT=0,
- CONSOLE_STATE_PROMPT,
- CONSOLE_STATE_USER_INPUT
-};
-
-static console_state con_st = CONSOLE_STATE_DEFAULT;
-static bool con_use_color = false;
-
-void set_console_state(console_state new_st) {
- if (!con_use_color) return;
- // only emit color code if state changed
- if (new_st != con_st) {
- con_st = new_st;
- switch(con_st) {
- case CONSOLE_STATE_DEFAULT:
- printf(ANSI_COLOR_RESET);
- return;
- case CONSOLE_STATE_PROMPT:
- printf(ANSI_COLOR_YELLOW);
- return;
- case CONSOLE_STATE_USER_INPUT:
- printf(ANSI_BOLD ANSI_COLOR_GREEN);
- return;
- }
- }
-}
+static console_state con_st;
static bool is_interacting = false;
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
void sigint_handler(int signo) {
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
printf("\n"); // this also force flush stdout.
if (signo == SIGINT) {
if (!is_interacting) {
@@ -81,32 +36,6 @@ void sigint_handler(int signo) {
}
#endif
-#if defined (_WIN32)
-void win32_console_init(void) {
- unsigned long dwMode = 0;
- void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11)
- if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) {
- hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12)
- if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) {
- hConOut = 0;
- }
- }
- if (hConOut) {
- // Enable ANSI colors on Windows 10+
- if (con_use_color && !(dwMode & 0x4)) {
- SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4)
- }
- // Set console output codepage to UTF8
- SetConsoleOutputCP(65001); // CP_UTF8
- }
- void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10)
- if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) {
- // Set console input codepage to UTF8
- SetConsoleCP(65001); // CP_UTF8
- }
-}
-#endif
-
int main(int argc, char ** argv) {
gpt_params params;
params.model = "models/llama-7B/ggml-model.bin";
@@ -115,13 +44,12 @@ int main(int argc, char ** argv) {
return 1;
}
-
// save choice to use color for later
// (note for later: this is a slightly awkward choice)
- con_use_color = params.use_color;
+ con_st.use_color = params.use_color;
#if defined (_WIN32)
- win32_console_init();
+ win32_console_init(params.use_color);
#endif
if (params.perplexity) {
@@ -218,7 +146,10 @@ int main(int argc, char ** argv) {
return 1;
}
- params.n_keep = std::min(params.n_keep, (int) embd_inp.size());
+ // number of tokens to keep when resetting context
+ if (params.n_keep < 0 || params.n_keep > (int)embd_inp.size() || params.instruct) {
+ params.n_keep = (int)embd_inp.size();
+ }
// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", true);
@@ -226,16 +157,12 @@ int main(int argc, char ** argv) {
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
- params.interactive = true;
+ params.interactive_start = true;
params.antiprompt.push_back("### Instruction:\n\n");
}
- // enable interactive mode if reverse prompt is specified
- if (params.antiprompt.size() != 0) {
- params.interactive = true;
- }
-
- if (params.interactive_start) {
+ // enable interactive mode if reverse prompt or interactive start is specified
+ if (params.antiprompt.size() != 0 || params.interactive_start) {
params.interactive = true;
}
@@ -297,17 +224,18 @@ int main(int argc, char ** argv) {
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
- is_interacting = params.interactive_start || params.instruct;
+ is_interacting = params.interactive_start;
}
- bool input_noecho = false;
+ bool is_antiprompt = false;
+ bool input_noecho = false;
int n_past = 0;
int n_remain = params.n_predict;
int n_consumed = 0;
// the first thing we will do is to output the prompt, so set color accordingly
- set_console_state(CONSOLE_STATE_PROMPT);
+ set_console_color(con_st, CONSOLE_COLOR_PROMPT);
std::vector<llama_token> embd;
@@ -408,36 +336,38 @@ int main(int argc, char ** argv) {
}
// reset color to default if we there is no pending user input
if (!input_noecho && (int)embd_inp.size() == n_consumed) {
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
}
// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
+
// check for reverse prompt
- std::string last_output;
- for (auto id : last_n_tokens) {
- last_output += llama_token_to_str(ctx, id);
- }
+ if (params.antiprompt.size()) {
+ std::string last_output;
+ for (auto id : last_n_tokens) {
+ last_output += llama_token_to_str(ctx, id);
+ }
- // Check if each of the reverse prompts appears at the end of the output.
- for (std::string & antiprompt : params.antiprompt) {
- if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
- is_interacting = true;
- set_console_state(CONSOLE_STATE_USER_INPUT);
- fflush(stdout);
- break;
+ is_antiprompt = false;
+ // Check if each of the reverse prompts appears at the end of the output.
+ for (std::string & antiprompt : params.antiprompt) {
+ if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) {
+ is_interacting = true;
+ is_antiprompt = true;
+ set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
+ fflush(stdout);
+ break;
+ }
}
}
if (n_past > 0 && is_interacting) {
// potentially set color to indicate we are taking user input
- set_console_state(CONSOLE_STATE_USER_INPUT);
+ set_console_color(con_st, CONSOLE_COLOR_USER_INPUT);
if (params.instruct) {
- n_consumed = embd_inp.size();
- embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
-
printf("\n> ");
}
@@ -463,16 +393,28 @@ int main(int argc, char ** argv) {
} while (another_line);
// done taking input, reset color
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
- auto line_inp = ::llama_tokenize(ctx, buffer, false);
- embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+ // Add tokens to embd only if the input buffer is non-empty
+ // Entering a empty line lets the user pass control back
+ if (buffer.length() > 1) {
- if (params.instruct) {
- embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
- }
+ // instruct mode: insert instruction prefix
+ if (params.instruct && !is_antiprompt) {
+ n_consumed = embd_inp.size();
+ embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
+ }
- n_remain -= line_inp.size();
+ auto line_inp = ::llama_tokenize(ctx, buffer, false);
+ embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
+
+ // instruct mode: insert response suffix
+ if (params.instruct) {
+ embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
+ }
+
+ n_remain -= line_inp.size();
+ }
input_noecho = true; // do not echo this again
}
@@ -506,7 +448,7 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
llama_free(ctx);
- set_console_state(CONSOLE_STATE_DEFAULT);
+ set_console_color(con_st, CONSOLE_COLOR_DEFAULT);
return 0;
}