| author | Johannes Gäßler <johannesg@5d6.de> | 2023-08-28 17:59:39 +0200 |
| committer | GitHub <noreply@github.com> | 2023-08-28 17:59:39 +0200 |
| commit | 6b73ef120114beb5664ea94aab48d07ed248ee52 (patch) | |
| tree | 6d9c777a34a43f7b3ad6185df9639bab9be5c5cd /examples/perplexity | |
| parent | 75fafcbcccc280a5b3883bc76d0a2dabf474d094 (diff) | |
YAML result logging + preset script (#2657)
Diffstat (limited to 'examples/perplexity')
| -rw-r--r-- | examples/perplexity/perplexity.cpp | 141 |
1 file changed, 115 insertions(+), 26 deletions(-)
```diff
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index ebafa0c2..aeb774c5 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -3,16 +3,79 @@
 #include "build-info.h"
 
 #include <cmath>
+#include <cstdio>
+#include <cstring>
 #include <ctime>
 #include <sstream>
-#include <cstring>
 #include <thread>
 #include <mutex>
+#include <vector>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+struct results_perplexity {
+    std::vector<llama_token> tokens;
+    double ppl_value;
+    std::vector<float> logits;
+    std::vector<float> probs;
+};
+
+struct results_log_softmax {
+    double log_softmax;
+    float logit;
+    float prob;
+};
+
+void write_logfile(const llama_context * ctx, const gpt_params & params,
+                   const llama_model * model, const struct results_perplexity & results) {
+
+    if (params.logdir.empty()) {
+        return;
+    }
+
+    if (params.hellaswag) {
+        fprintf(stderr, "%s: warning: logging results is not implemented for HellaSwag. No files will be written.\n", __func__);
+        return;
+    }
+
+    const std::string timestamp = get_sortable_timestamp();
+
+    const bool success = create_directory_with_parents(params.logdir);
+    if (!success) {
+        fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n",
+                __func__, params.logdir.c_str());
+        return;
+    }
+
+    const std::string logfile_path = params.logdir + timestamp + ".yml";
+    FILE * logfile = fopen(logfile_path.c_str(), "w");
+
+    if (logfile == NULL) {
+        fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str());
+        return;
+    }
+
+    fprintf(logfile, "binary: main\n");
+    char model_desc[128];
+    llama_model_desc(model, model_desc, sizeof(model_desc));
+    dump_non_result_info_yaml(logfile, params, ctx, timestamp, results.tokens, model_desc);
+
+    fprintf(logfile, "\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "# Perplexity Results #\n");
+    fprintf(logfile, "######################\n");
+    fprintf(logfile, "\n");
+
+    dump_vector_float_yaml(logfile, "logits", results.logits);
+    fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
+    dump_vector_float_yaml(logfile, "probs", results.probs);
+
+    llama_dump_timing_info_yaml(logfile, ctx);
+    fclose(logfile);
+}
+
 std::vector<float> softmax(const std::vector<float>& logits) {
     std::vector<float> probs(logits.size());
     float max_logit = logits[0];
@@ -29,20 +92,20 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
-float log_softmax(int n_vocab, const float * logits, int tok) {
+results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
     float max_logit = logits[0];
     for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
     double sum_exp = 0.0;
     for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
-    return logits[tok] - max_logit - log(sum_exp);
+    return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
 }
 
-void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
-                    double& nll, double& nll2) {
+void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
+                    double & nll, double & nll2, float * logit_history, float * prob_history) {
 
     std::mutex mutex;
     int counter = 0;
-    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+    auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
         double local_nll = 0, local_nll2 = 0;
         while (true) {
             std::unique_lock<std::mutex> lock(mutex);
@@ -52,34 +115,43 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n
                 break;
             }
             lock.unlock();
-            double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            const double v = -results.log_softmax;
             local_nll += v;
             local_nll2 += v*v;
+
+            logit_history[i] = results.logit;
+            prob_history[i] = results.prob;
         }
     };
-    for (auto& w : workers) w = std::thread(compute);
+    for (auto & w : workers) w = std::thread(compute);
     compute();
-    for (auto& w : workers) w.join();
+    for (auto & w : workers) w.join();
 }
 
-void perplexity_v2(llama_context * ctx, const gpt_params & params) {
+results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
     // BOS tokens will be added for each chunk before eval
 
-    if (params.ppl_stride <= 0) {
-        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
-        return;
-    }
-
     const bool is_spm = llama_vocab_type(ctx) == LLAMA_VOCAB_TYPE_SPM;
     const bool add_bos = is_spm;
 
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<float> logit_history;
+    std::vector<float> prob_history;
+
+    logit_history.resize(tokens.size());
+    prob_history.resize(tokens.size());
+
+    if (params.ppl_stride <= 0) {
+        fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride);
+        return {tokens, -1, logit_history, prob_history};
+    }
 
     const int calc_chunk = params.n_ctx;
 
@@ -88,7 +160,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
     if (int(tokens.size()) <= calc_chunk) {
         fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__,
                 tokens.size(), params.n_ctx, params.ppl_stride);
-        return;
+        return {tokens, -1, logit_history, prob_history};
     }
 
     const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
@@ -120,7 +192,7 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
             //fprintf(stderr, "    Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 //fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return {tokens, -1, logit_history, prob_history};
             }
 
             // save original token and restore it after eval
@@ -161,6 +233,8 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
                 logits.begin() + (j + 1) * n_vocab);
 
             const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+            logit_history[start + j + 1] = tok_logits[tokens[start + j + 1]];
+            prob_history[start + j + 1] = prob;
 
             nll += -std::log(prob);
             ++count;
@@ -174,12 +248,14 @@ void perplexity_v2(llama_context * ctx, const gpt_params & params) {
         fflush(stdout);
     }
     printf("\n");
+
+    return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-void perplexity(llama_context * ctx, const gpt_params & params) {
+results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+
     if (params.ppl_stride > 0) {
-        perplexity_v2(ctx, params);
-        return;
+        return perplexity_v2(ctx, params);
     }
 
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
@@ -193,11 +269,17 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
 
-    auto tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
+    std::vector<llama_token> tokens = ::llama_tokenize(ctx, params.prompt, add_bos);
 
     auto tim2 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenization took %g ms\n",__func__,1e-3*std::chrono::duration_cast<std::chrono::microseconds>(tim2-tim1).count());
 
+    std::vector<float> logit_history;
+    logit_history.resize(tokens.size());
+
+    std::vector<float> prob_history;
+    prob_history.resize(tokens.size());
+
     const int n_chunk_max = tokens.size() / params.n_ctx;
 
     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
@@ -236,7 +318,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
             if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
-                return;
+                return {tokens, -1, logit_history, prob_history};
             }
 
             // restore the original token in case it was set to BOS
@@ -272,7 +354,8 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
        // last 256 tokens. Then, we split the input up into context window size chunks to
        // process the entire prompt.
        const int first = std::min(512, params.n_ctx/2);
-        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first,
+                       workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
        count += params.n_ctx - first - 1;
 
        // perplexity is e^(average negative log-likelihood)
@@ -287,16 +370,19 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         fflush(stdout);
     }
     printf("\n");
+
     nll2 /= count;
     nll /= count;
+    const double ppl = exp(nll);
     nll2 -= nll * nll;
     if (nll2 > 0) {
         nll2 = sqrt(nll2/(count-1));
-        double ppl = exp(nll);
         printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
     } else {
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
+
+    return {tokens, ppl, logit_history, prob_history};
 }
 
 std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
@@ -604,13 +690,16 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
+    struct results_perplexity results;
     if (params.hellaswag) {
         hellaswag_score(ctx, params);
     } else {
-        perplexity(ctx, params);
+        results = perplexity(ctx, params);
     }
 
     llama_print_timings(ctx);
+    write_logfile(ctx, params, model, results);
+
     llama_free(ctx);
     llama_free_model(model);
```
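For reference, the quantities recorded by the new `results_log_softmax` struct and the final `ppl = exp(nll / count)` step can be exercised in isolation. The snippet below is a minimal standalone sketch, not part of the patch: the helper name `stable_log_softmax`, the toy 4-token vocabulary, and the logits are invented purely for illustration.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Per-token result, mirroring the fields of the patch's results_log_softmax:
// the log-probability, the raw logit, and the probability of the target token.
struct log_softmax_result {
    double log_prob;
    float  logit;
    float  prob;
};

static log_softmax_result stable_log_softmax(const std::vector<float> & logits, int tok) {
    // subtract the max logit before exponentiating to avoid overflow
    float max_logit = logits[0];
    for (float l : logits) max_logit = std::max(max_logit, l);
    double sum_exp = 0.0;
    for (float l : logits) sum_exp += std::exp(l - max_logit);
    return { logits[tok] - max_logit - std::log(sum_exp),
             logits[tok],
             (float) (std::exp(logits[tok] - max_logit) / sum_exp) };
}

int main() {
    // made-up logits for a toy 4-token vocabulary at two positions
    const std::vector<std::vector<float>> logits = { {2.0f, 0.5f, -1.0f, 0.0f},
                                                     {0.1f, 1.5f,  0.3f, -0.5f} };
    const std::vector<int> targets = {0, 1}; // "correct" next token at each position

    double nll = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        const log_softmax_result r = stable_log_softmax(logits[i], targets[i]);
        printf("pos %zu: logit = %.3f  prob = %.3f  log_prob = %.3f\n", i, r.logit, r.prob, r.log_prob);
        nll += -r.log_prob; // accumulate negative log-likelihood
    }
    // perplexity is e^(average negative log-likelihood), as in perplexity()
    printf("ppl = %.4f\n", std::exp(nll / (double) logits.size()));
    return 0;
}
```

When a log directory is configured (via `gpt_params::logdir`; the corresponding command-line option lives outside this file and is not shown in this diff), `write_logfile` writes `<logdir><timestamp>.yml` containing a `binary: main` line, the run metadata emitted by `dump_non_result_info_yaml`, a `# Perplexity Results #` section with the per-token `logits` and `probs` vectors plus the scalar `ppl_value`, and finally the timing block from `llama_dump_timing_info_yaml`.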