summaryrefslogtreecommitdiff
path: root/examples/perplexity
diff options
context:
space:
mode:
Diffstat (limited to 'examples/perplexity')
-rw-r--r--examples/perplexity/perplexity.cpp51
1 files changed, 36 insertions, 15 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 2b375e34..de08bd4a 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -80,7 +80,9 @@ static void write_logfile(
static std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size());
float max_logit = logits[0];
- for (float v : logits) max_logit = std::max(max_logit, v);
+ for (float v : logits) {
+ max_logit = std::max(max_logit, v);
+ }
double sum_exp = 0.0;
for (size_t i = 0; i < logits.size(); i++) {
// Subtract the maximum logit value from the current logit value for numerical stability
@@ -89,15 +91,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
sum_exp += exp_logit;
probs[i] = exp_logit;
}
- for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
+ for (size_t i = 0; i < probs.size(); i++) {
+ probs[i] /= sum_exp;
+ }
return probs;
}
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
float max_logit = logits[0];
- for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+ for (int i = 1; i < n_vocab; ++i) {
+ max_logit = std::max(max_logit, logits[i]);
+ }
double sum_exp = 0.0;
- for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+ for (int i = 0; i < n_vocab; ++i) {
+ sum_exp += expf(logits[i] - max_logit);
+ }
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
}
@@ -108,7 +116,8 @@ static void process_logits(
std::mutex mutex;
int counter = 0;
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
- double local_nll = 0, local_nll2 = 0;
+ double local_nll = 0;
+ double local_nll2 = 0;
while (true) {
std::unique_lock<std::mutex> lock(mutex);
int i = counter++;
@@ -126,10 +135,13 @@ static void process_logits(
prob_history[i] = results.prob;
}
};
- for (auto & w : workers) w = std::thread(compute);
+ for (auto & w : workers) {
+ w = std::thread(compute);
+ }
compute();
- for (auto & w : workers) w.join();
-
+ for (auto & w : workers) {
+ w.join();
+ }
}
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
@@ -152,8 +164,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {std::move(tokens), 0., {}, {}};
}
- std::vector<float> logit_history;
- std::vector<float> prob_history;
+ std::vector<float> logit_history;
+ std::vector<float> prob_history;
logit_history.resize(tokens.size());
prob_history.resize(tokens.size());
@@ -195,12 +207,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
+ // clear the KV cache
+ llama_kv_cache_tokens_rm(ctx, -1, -1);
+
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
- if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@@ -320,6 +335,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
+ // clear the KV cache
+ llama_kv_cache_tokens_rm(ctx, -1, -1);
+
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
@@ -332,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
tokens[batch_start] = llama_token_bos(ctx);
}
- if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@@ -402,7 +420,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
static std::vector<float> hellaswag_evaluate_tokens(
- llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
+ llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
) {
std::vector<float> result;
result.reserve(tokens.size() * n_vocab);
@@ -410,7 +428,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
- if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
+ if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}
@@ -550,6 +568,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
query_embd.resize(32);
}
+ // clear the KV cache
+ llama_kv_cache_tokens_rm(ctx, -1, -1);
+
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
@@ -661,7 +682,7 @@ int main(int argc, char ** argv) {
return 1;
}
- params.perplexity = true;
+ params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.ppl_stride > 0) {