From 6b73ef120114beb5664ea94aab48d07ed248ee52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Mon, 28 Aug 2023 17:59:39 +0200 Subject: YAML result logging + preset script (#2657) --- examples/main/main.cpp | 78 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 76 insertions(+), 2 deletions(-) (limited to 'examples/main/main.cpp') diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3ce57f43..89cc4f60 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -36,9 +37,57 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static llama_context ** g_ctx; +static llama_context ** g_ctx; +static llama_model ** g_model; +static gpt_params * g_params; +static std::vector * g_input_tokens; +static std::ostringstream * g_output_ss; +static std::vector * g_output_tokens; static bool is_interacting = false; +void write_logfile( + const llama_context * ctx, const gpt_params & params, const llama_model * model, + const std::vector input_tokens, const std::string output, const std::vector output_tokens) { + + if (params.logdir.empty()) { + return; + } + + const std::string timestamp = get_sortable_timestamp(); + + const bool success = create_directory_with_parents(params.logdir); + if (!success) { + fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", + __func__, params.logdir.c_str()); + return; + } + + const std::string logfile_path = params.logdir + timestamp + ".yml"; + FILE * logfile = fopen(logfile_path.c_str(), "w"); + + if (logfile == NULL) { + fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); + return; + } + + fprintf(logfile, "binary: main\n"); + char model_desc[128]; + llama_model_desc(model, model_desc, sizeof(model_desc)); + dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc); + + fprintf(logfile, "\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "# Generation Results #\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "\n"); + + dump_string_yaml_multiline(logfile, "output", output.c_str()); + dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + + llama_dump_timing_info_yaml(logfile, ctx); + fclose(logfile); +} + #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { if (signo == SIGINT) { @@ -48,6 +97,7 @@ void sigint_handler(int signo) { console::cleanup(); printf("\n"); llama_print_timings(*g_ctx); + write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } } @@ -56,6 +106,7 @@ void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; + g_params = ¶ms; if (gpt_params_parse(argc, argv, params) == false) { return 1; @@ -116,6 +167,7 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; llama_context * ctx_guidance = NULL; + g_model = &model; g_ctx = &ctx; // load the model and apply lora adapter, if any @@ -397,6 +449,10 @@ int main(int argc, char ** argv) { int n_session_consumed = 0; int n_past_guidance = 0; + std::vector input_tokens; g_input_tokens = &input_tokens; + std::vector output_tokens; g_output_tokens = &output_tokens; + std::ostringstream output_ss; g_output_ss = &output_ss; + // the first thing we will do is to output the prompt, so set color accordingly console::set_display(console::prompt); @@ -667,7 +723,15 @@ int main(int argc, char ** argv) { // display text if (input_echo) { for (auto id : embd) { - printf("%s", llama_token_to_piece(ctx, id).c_str()); + const std::string token_str = llama_token_to_piece(ctx, id); + printf("%s", token_str.c_str()); + + if (embd.size() > 1) { + input_tokens.push_back(id); + } else { + output_tokens.push_back(id); + output_ss << token_str; + } } fflush(stdout); } @@ -761,6 +825,8 @@ int main(int argc, char ** argv) { printf("%s", params.input_suffix.c_str()); } + const size_t original_size = embd_inp.size(); + // instruct mode: insert instruction prefix if (params.instruct && !is_antiprompt) { n_consumed = embd_inp.size(); @@ -775,6 +841,12 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); } + for (size_t i = original_size; i < embd_inp.size(); ++i) { + const llama_token token = embd_inp[i]; + output_tokens.push_back(token); + output_ss << llama_token_to_piece(ctx, token); + } + n_remain -= line_inp.size(); } @@ -817,6 +889,8 @@ int main(int argc, char ** argv) { } llama_print_timings(ctx); + write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + if (ctx_guidance) { llama_free(ctx_guidance); } llama_free(ctx); llama_free_model(model); -- cgit v1.2.3