diff options
Diffstat (limited to 'examples/perplexity/perplexity.cpp')
-rw-r--r-- | examples/perplexity/perplexity.cpp | 30 |
1 files changed, 16 insertions, 14 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 3a1c8c28..4620c43a 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -28,9 +28,10 @@ struct results_log_softmax { float prob; }; -void write_logfile(const llama_context * ctx, const gpt_params & params, - const llama_model * model, const struct results_perplexity & results) { - +static void write_logfile( + const llama_context * ctx, const gpt_params & params, const llama_model * model, + const struct results_perplexity & results +) { if (params.logdir.empty()) { return; } @@ -76,7 +77,7 @@ void write_logfile(const llama_context * ctx, const gpt_params & params, fclose(logfile); } -std::vector<float> softmax(const std::vector<float>& logits) { +static std::vector<float> softmax(const std::vector<float>& logits) { std::vector<float> probs(logits.size()); float max_logit = logits[0]; for (float v : logits) max_logit = std::max(max_logit, v); @@ -92,7 +93,7 @@ std::vector<float> softmax(const std::vector<float>& logits) { return probs; } -results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { +static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { float max_logit = logits[0]; for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]); double sum_exp = 0.0; @@ -100,9 +101,10 @@ results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; } -void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers, - double & nll, double & nll2, float * logit_history, float * prob_history) { - +static void process_logits( + int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers, + double & nll, double & nll2, float * logit_history, float * prob_history +) { std::mutex mutex; int counter = 0; auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { @@ -130,7 +132,7 @@ void process_logits(int n_vocab, const float * logits, const int * tokens, int n } -results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { +static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` @@ -260,8 +262,7 @@ results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) return {tokens, std::exp(nll / count), logit_history, prob_history}; } -results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { - +static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { if (params.ppl_stride > 0) { return perplexity_v2(ctx, params); } @@ -400,8 +401,9 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { return {tokens, ppl, logit_history, prob_history}; } -std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, - int n_vocab, int n_thread) { +static std::vector<float> hellaswag_evaluate_tokens( + llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread +) { std::vector<float> result; result.reserve(tokens.size() * n_vocab); size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch; @@ -421,7 +423,7 @@ std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vec return result; } -void hellaswag_score(llama_context * ctx, const gpt_params & params) { +static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Calculates hellaswag score (acc_norm) from prompt // // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl |