summaryrefslogtreecommitdiff
path: root/examples/batched/batched.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/batched/batched.cpp')
-rw-r--r--examples/batched/batched.cpp26
1 files changed, 8 insertions, 18 deletions
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index a88e022d..15521216 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -97,20 +97,15 @@ int main(int argc, char ** argv) {
fflush(stderr);
- // create a llama_batch with size 512
+ // create a llama_batch
// we use this object to submit token data for decoding
-
- llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
+ llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
// evaluate the initial prompt
- batch.n_tokens = tokens_list.size();
-
- for (int32_t i = 0; i < batch.n_tokens; i++) {
- batch.token[i] = tokens_list[i];
- batch.pos[i] = i;
- batch.seq_id[i] = 0;
- batch.logits[i] = false;
+ for (size_t i = 0; i < tokens_list.size(); ++i) {
+ llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
}
+ GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
@@ -146,7 +141,7 @@ int main(int argc, char ** argv) {
while (n_cur <= n_len) {
// prepare the next batch
- batch.n_tokens = 0;
+ llama_batch_clear(batch);
// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -198,15 +193,10 @@ int main(int argc, char ** argv) {
streams[i] += llama_token_to_piece(ctx, new_token_id);
- // push this new token for next evaluation
- batch.token [batch.n_tokens] = new_token_id;
- batch.pos [batch.n_tokens] = n_cur;
- batch.seq_id[batch.n_tokens] = i;
- batch.logits[batch.n_tokens] = true;
-
i_batch[i] = batch.n_tokens;
- batch.n_tokens += 1;
+ // push this new token for next evaluation
+ llama_batch_add(batch, new_token_id, n_cur, { i }, true);
n_decode += 1;
}