summaryrefslogtreecommitdiff
path: root/examples/llama-bench/llama-bench.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/llama-bench/llama-bench.cpp')
-rw-r--r--examples/llama-bench/llama-bench.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 2f1a1d9f..058e34d5 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
int n_processed = 0;
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
- llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
n_processed += n_tokens;
}
}
@@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
for (int i = 0; i < n_gen; i++) {
- llama_eval(ctx, &token, 1, n_past + i, n_threads);
+ llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
}
}
@@ -977,6 +977,8 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
+ llama_kv_cache_tokens_rm(ctx, -1, -1);
+
// warmup run
if (t.n_prompt > 0) {
test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
@@ -986,6 +988,8 @@ int main(int argc, char ** argv) {
}
for (int i = 0; i < params.reps; i++) {
+ llama_kv_cache_tokens_rm(ctx, -1, -1);
+
uint64_t t_start = get_time_ns();
if (t.n_prompt > 0) {
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);