From 7b8dbcb78b2f65c4676e41da215800d65846edd0 Mon Sep 17 00:00:00 2001 From: anzz1 Date: Tue, 28 Mar 2023 17:09:55 +0300 Subject: main.cpp fixes, refactoring (#571) - main: entering empty line passes back control without new input in interactive/instruct modes - instruct mode: keep prompt fix - instruct mode: duplicate instruct prompt fix - refactor: move common console code from main->common --- examples/common.cpp | 67 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 7 deletions(-) (limited to 'examples/common.cpp') diff --git a/examples/common.cpp b/examples/common.cpp index 2ab000f4..880ebe9a 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -9,11 +9,20 @@ #include #include - #if defined(_MSC_VER) || defined(__MINGW32__) - #include // using malloc.h with MSC/MINGW - #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) - #include - #endif +#if defined(_MSC_VER) || defined(__MINGW32__) +#include // using malloc.h with MSC/MINGW +#elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) +#include +#endif + +#if defined (_WIN32) +#pragma comment(lib,"kernel32.lib") +extern "C" __declspec(dllimport) void* __stdcall GetStdHandle(unsigned long nStdHandle); +extern "C" __declspec(dllimport) int __stdcall GetConsoleMode(void* hConsoleHandle, unsigned long* lpMode); +extern "C" __declspec(dllimport) int __stdcall SetConsoleMode(void* hConsoleHandle, unsigned long dwMode); +extern "C" __declspec(dllimport) int __stdcall SetConsoleCP(unsigned int wCodePageID); +extern "C" __declspec(dllimport) int __stdcall SetConsoleOutputCP(unsigned int wCodePageID); +#endif bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { // determine sensible default number of threads. @@ -204,7 +213,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --in-prefix STRING string to prefix user inputs with (default: empty)\n"); fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " prompt file to start generation.\n"); - fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 - infinity)\n", params.n_predict); + fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " --top_k N top-k sampling (default: %d)\n", params.top_k); fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n); @@ -216,7 +225,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); - fprintf(stderr, " --keep number of tokens to keep from the initial prompt\n"); + fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); if (ggml_mlock_supported()) { fprintf(stderr, " --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -256,3 +265,47 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } + +/* Keep track of current color of output, and emit ANSI code if it changes. */ +void set_console_color(console_state & con_st, console_color_t color) { + if (con_st.use_color && con_st.color != color) { + switch(color) { + case CONSOLE_COLOR_DEFAULT: + printf(ANSI_COLOR_RESET); + break; + case CONSOLE_COLOR_PROMPT: + printf(ANSI_COLOR_YELLOW); + break; + case CONSOLE_COLOR_USER_INPUT: + printf(ANSI_BOLD ANSI_COLOR_GREEN); + break; + } + con_st.color = color; + } +} + +#if defined (_WIN32) +void win32_console_init(bool enable_color) { + unsigned long dwMode = 0; + void* hConOut = GetStdHandle((unsigned long)-11); // STD_OUTPUT_HANDLE (-11) + if (!hConOut || hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode)) { + hConOut = GetStdHandle((unsigned long)-12); // STD_ERROR_HANDLE (-12) + if (hConOut && (hConOut == (void*)-1 || !GetConsoleMode(hConOut, &dwMode))) { + hConOut = 0; + } + } + if (hConOut) { + // Enable ANSI colors on Windows 10+ + if (enable_color && !(dwMode & 0x4)) { + SetConsoleMode(hConOut, dwMode | 0x4); // ENABLE_VIRTUAL_TERMINAL_PROCESSING (0x4) + } + // Set console output codepage to UTF8 + SetConsoleOutputCP(65001); // CP_UTF8 + } + void* hConIn = GetStdHandle((unsigned long)-10); // STD_INPUT_HANDLE (-10) + if (hConIn && hConIn != (void*)-1 && GetConsoleMode(hConIn, &dwMode)) { + // Set console input codepage to UTF8 + SetConsoleCP(65001); // CP_UTF8 + } +} +#endif -- cgit v1.2.3