Diffstat (limited to 'examples/batched-bench/batched-bench.cpp')
-rw-r--r--  examples/batched-bench/batched-bench.cpp | 13 ++++++++-----
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 19aff18a..dff6c68e 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
     ctx_params.n_threads       = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
 
+    // ensure enough sequences are available
+    ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());
+
     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
 
     if (ctx == NULL) {
@@ -174,10 +177,10 @@ int main(int argc, char ** argv) {
 
                 llama_batch_clear(batch);
 
-                const int n_tokens = is_pp_shared ? pp : pl*pp;
-
-                for (int i = 0; i < n_tokens; ++i) {
-                    llama_batch_add(batch, 0, i, { 0 }, false);
+                for (int i = 0; i < pp; ++i) {
+                    for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                        llama_batch_add(batch, 0, i, { j }, false);
+                    }
                 }
                 batch.logits[batch.n_tokens - 1] = true;
@@ -192,7 +195,7 @@ int main(int argc, char ** argv) {
 
                 if (is_pp_shared) {
                     for (int32_t i = 1; i < pl; ++i) {
-                        llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+                        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
                     }
                 }
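
The first two hunks change how the prompt part of the benchmark batch is built. Previously the example pushed all prompt tokens into sequence 0 and, in the non-shared case, simply made the batch pl*pp tokens long; now it walks token positions in the outer loop and sequence ids in the inner loop, so each of the pl sequences gets its own copy of the pp prompt tokens (or only sequence 0 when the prompt is shared). Because sequence ids up to pl - 1 are now handed to llama_batch_add, the context is created with ctx_params.n_parallel set to the largest value in n_pl, so the KV cache has enough sequences. A minimal sketch of the new fill pattern, using the llama_batch_clear/llama_batch_add helpers from the examples' common.h; the helper name fill_prompt_batch and the dummy token id 0 are illustrative only:

#include "common.h"
#include "llama.h"

// Sketch: fill `batch` with a dummy prompt of `pp` tokens for `pl` parallel
// sequences, mirroring the interleaved loop introduced by this patch. With a
// shared prompt only sequence 0 is filled; the other sequences are attached
// to its KV cache afterwards (see the note below).
static void fill_prompt_batch(llama_batch & batch, int pp, int pl, bool is_pp_shared) {
    llama_batch_clear(batch);

    for (int i = 0; i < pp; ++i) {                          // token position
        for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { // sequence id
            llama_batch_add(batch, 0, i, { j }, false);     // dummy token 0, no logits
        }
    }

    // request logits only for the last token in the batch
    batch.logits[batch.n_tokens - 1] = true;
}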
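
In the shared-prompt case, the last hunk still copies sequence 0's KV cache entries to the other sequences, but now passes -1 for both position bounds of llama_kv_cache_seq_cp, which llama.h documents as selecting the whole range; this keeps the copy correct without restating the prompt length. A small sketch of that step; the helper name share_prompt_cache is illustrative only:

#include "llama.h"

// Sketch: after decoding a shared prompt into sequence 0, attach every other
// sequence to the same cached prompt. p0 = -1 and p1 = -1 copy the entire
// position range of the source sequence.
static void share_prompt_cache(llama_context * ctx, int pl) {
    for (int32_t i = 1; i < pl; ++i) {
        llama_kv_cache_seq_cp(ctx, /*seq_id_src=*/0, /*seq_id_dst=*/i, /*p0=*/-1, /*p1=*/-1);
    }
}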