diff options
author | Thibault Terrasson <thibault.terrasson@gmail.com> | 2023-10-27 16:37:41 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-27 08:37:41 -0600 |
commit | c8d6a1f34ab6f1b6bd468d256e535a61f98f114c (patch) | |
tree | fa6d27fd74eba86e4b27d71e426c2866cf321dc0 /examples | |
parent | 2f9ec7e271220a78fe27c9e6ccbcc0dda31cda0f (diff) |
simple : fix batch handling (#3803)
Diffstat (limited to 'examples')
-rw-r--r-- | examples/simple/simple.cpp | 18 |
1 files changed, 4 insertions, 14 deletions
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index f376c050..374aef6f 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -95,13 +95,8 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(512, 0, 1); // evaluate the initial prompt - batch.n_tokens = tokens_list.size(); - - for (int32_t i = 0; i < batch.n_tokens; i++) { - batch.token[i] = tokens_list[i]; - batch.pos[i] = i; - batch.seq_id[i] = 0; - batch.logits[i] = false; + for (size_t i = 0; i < tokens_list.size(); i++) { + llama_batch_add(batch, tokens_list[i], i, { 0 }, false); } // llama_decode will output logits only for the last token of the prompt @@ -148,15 +143,10 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch - batch.n_tokens = 0; + llama_batch_clear(batch); // push this new token for next evaluation - batch.token [batch.n_tokens] = new_token_id; - batch.pos [batch.n_tokens] = n_cur; - batch.seq_id[batch.n_tokens] = 0; - batch.logits[batch.n_tokens] = true; - - batch.n_tokens += 1; + llama_batch_add(batch, new_token_id, n_cur, { 0 }, true); n_decode += 1; } |