diff options
Diffstat (limited to 'examples/batched-bench')
-rw-r--r-- | examples/batched-bench/batched-bench.cpp | 38 |
1 files changed, 15 insertions, 23 deletions
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 3e1e0716..c552eaa7 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -114,7 +114,7 @@ int main(int argc, char ** argv) { return 1; } - llama_batch batch = llama_batch_init(n_kv_max, 0); + llama_batch batch = llama_batch_init(n_kv_max, 0, 1); // decode in batches of ctx_params.n_batch tokens auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { @@ -123,11 +123,12 @@ int main(int argc, char ** argv) { llama_batch batch_view = { n_tokens, - batch.token + i, + batch.token + i, nullptr, - batch.pos + i, - batch.seq_id + i, - batch.logits + i, + batch.pos + i, + batch.n_seq_id + i, + batch.seq_id + i, + batch.logits + i, 0, 0, 0, // unused }; @@ -143,13 +144,8 @@ int main(int argc, char ** argv) { // warm up { - batch.n_tokens = 16; - - for (int i = 0; i < batch.n_tokens; ++i) { - batch.token[i] = 0; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + for (int i = 0; i < 16; ++i) { + llama_batch_add(batch, 0, i, { 0 }, false); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -174,13 +170,12 @@ int main(int argc, char ** argv) { continue; } - batch.n_tokens = is_pp_shared ? pp : pl*pp; + llama_batch_clear(batch); + + const int n_tokens = is_pp_shared ? pp : pl*pp; - for (int i = 0; i < batch.n_tokens; ++i) { - batch.token[i] = 0; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + for (int i = 0; i < n_tokens; ++i) { + llama_batch_add(batch, 0, i, { 0 }, false); } batch.logits[batch.n_tokens - 1] = true; @@ -204,13 +199,10 @@ int main(int argc, char ** argv) { const auto t_tg_start = ggml_time_us(); for (int i = 0; i < tg; ++i) { - batch.n_tokens = pl; + llama_batch_clear(batch); for (int j = 0; j < pl; ++j) { - batch.token[j] = 0; - batch.pos[j] = pp + i; - batch.seq_id[j] = j; - batch.logits[j] = true; + llama_batch_add(batch, 0, pp + i, { j }, true); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { |