summaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r-- examples/perplexity/perplexity.cpp | 6
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 372684f0..95aedce6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -166,7 +166,7 @@ static void process_logits(
break;
}
lock.unlock();
- const results_log_softmax results = log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+ const results_log_softmax results = log_softmax(n_vocab, logits + int64_t(i)*n_vocab, tokens[i+1]);
const double v = -results.log_softmax;
local_nll += v;
local_nll2 += v*v;
@@ -200,7 +200,7 @@ static void process_logits(std::ostream& out, int n_vocab, const float * logits,
break;
}
lock.unlock();
- const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + i*nv, tokens[i+1]);
+ const double v = log_softmax(n_vocab, logits + i*n_vocab, log_probs.data() + int64_t(i)*nv, tokens[i+1]);
local_nll += v;
local_nll2 += v*v;
}
@@ -618,7 +618,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + int64_t(n_outputs) * n_vocab);
}
}