Diffstat (limited to 'examples/batched-bench')
-rw-r--r--  examples/batched-bench/batched-bench.cpp | 38 +++++++++++++++-----------------------
1 file changed, 15 insertions(+), 23 deletions(-)
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 3e1e0716..c552eaa7 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    llama_batch batch = llama_batch_init(n_kv_max, 0);
+    llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
@@ -123,11 +123,12 @@ int main(int argc, char ** argv) {
 
             llama_batch batch_view = {
                 n_tokens,
-                batch.token  + i,
+                batch.token    + i,
                 nullptr,
-                batch.pos    + i,
-                batch.seq_id + i,
-                batch.logits + i,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
                 0, 0, 0, // unused
             };
 
@@ -143,13 +144,8 @@ int main(int argc, char ** argv) {
 
     // warm up
    {
-        batch.n_tokens = 16;
-
-        for (int i = 0; i < batch.n_tokens; ++i) {
-            batch.token[i]  = 0;
-            batch.pos[i]    = i;
-            batch.seq_id[i] = 0;
-            batch.logits[i] = false;
+        for (int i = 0; i < 16; ++i) {
+            llama_batch_add(batch, 0, i, { 0 }, false);
         }
 
         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -174,13 +170,12 @@ int main(int argc, char ** argv) {
                     continue;
                 }
 
-                batch.n_tokens = is_pp_shared ? pp : pl*pp;
+                llama_batch_clear(batch);
+
+                const int n_tokens = is_pp_shared ? pp : pl*pp;
 
-                for (int i = 0; i < batch.n_tokens; ++i) {
-                    batch.token[i]  = 0;
-                    batch.pos[i]    = i;
-                    batch.seq_id[i] = 0;
-                    batch.logits[i] = false;
+                for (int i = 0; i < n_tokens; ++i) {
+                    llama_batch_add(batch, 0, i, { 0 }, false);
                 }
 
                 batch.logits[batch.n_tokens - 1] = true;
@@ -204,13 +199,10 @@ int main(int argc, char ** argv) {
                 const auto t_tg_start = ggml_time_us();
 
                 for (int i = 0; i < tg; ++i) {
-                    batch.n_tokens = pl;
+                    llama_batch_clear(batch);
 
                     for (int j = 0; j < pl; ++j) {
-                        batch.token[j]  = 0;
-                        batch.pos[j]    = pp + i;
-                        batch.seq_id[j] = j;
-                        batch.logits[j] = true;
+                        llama_batch_add(batch, 0, pp + i, { j }, true);
                     }
 
                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
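
Note on the batch_view hunk: it tracks a change to the llama_batch struct itself, where seq_id went from one id per token to a per-token array with a parallel n_seq_id count. A brace-initialized view like the one above depends on field order, which at this point in the tree looks roughly like the sketch below (reconstructed from llama.h of this era; treat it as illustrative, and note the trailing all_* fields are the three zeros marked "unused"):

typedef struct llama_batch {
    int32_t n_tokens;

    llama_token  *  token;
    float        *  embd;
    llama_pos    *  pos;
    int32_t      *  n_seq_id;  // number of sequence ids for each token
    llama_seq_id ** seq_id;    // seq_id[i] points to token i's array of ids
    int8_t       *  logits;    // whether to output logits for each token

    llama_pos    all_pos_0;    // used if pos    == NULL
    llama_pos    all_pos_1;    // used if pos    == NULL
    llama_seq_id all_seq_id;   // used if seq_id == NULL
} llama_batch;

Because seq_id holds one pointer per token, offsetting it with batch.seq_id + i still yields a valid window over tokens [i, i + n_tokens), which is why the sliding batch_view keeps working.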
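
The helpers used above, llama_batch_clear and llama_batch_add, replace the manual writes into the llama_batch arrays but are not shown in this diff. As a rough sketch, assuming the common.h helpers from this point in the tree (field names match llama.h; this is illustrative, not the authoritative implementation):

#include <vector>
#include "llama.h"

// reset the batch to empty; the backing arrays stay allocated
static void llama_batch_clear(struct llama_batch & batch) {
    batch.n_tokens = 0;
}

// append one token at position pos, tagged with one or more sequence ids;
// logits marks whether llama_decode should produce logits for this token
static void llama_batch_add(
        struct llama_batch & batch,
        llama_token id, llama_pos pos,
        const std::vector<llama_seq_id> & seq_ids,
        bool logits) {
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}

llama_batch_add never grows the arrays, so the batch must be created with enough capacity up front; that is what llama_batch_init(n_kv_max, 0, 1) provides, with the new third argument bounding the number of sequence ids per token (1 is enough here, since every llama_batch_add call above tags each token with a single sequence id).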