| author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-18 16:21:57 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-18 16:21:57 +0300 |
| commit | 0e89203b517c95ec6675eda75d200a60d1e8921d (patch) | |
| tree | 3aba40ef0362d061f240bd43c52e86a8f728f89d /examples/batched/batched.cpp | |
| parent | c67fe68e417f766970fb1feaf2e66458aa24116a (diff) | |
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context
ggml-ci
* speculative : add tree-based sampling support
ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring
ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc
ggml-ci
* swift : fix build
ggml-ci
* swift : try to fix build
ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers (see the sketch after this list)
* speculative : minor refactor
ggml-ci
* minor : comments + rename
ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
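The `llama_batch_add()` and `llama_batch_clear()` helpers mentioned above live in the examples' common code. As a rough sketch (not the verbatim implementation), their behavior can be inferred from how the diff below calls them; the actual definitions in `common.h`/`common.cpp` may differ in signature details:

```cpp
// Sketch of the common helpers, inferred from their call sites in the diff below.
#include <vector>

#include "llama.h"

void llama_batch_clear(llama_batch & batch) {
    // keep the allocated buffers, just drop the queued tokens
    batch.n_tokens = 0;
}

void llama_batch_add(llama_batch & batch,
                     llama_token id,
                     llama_pos pos,
                     const std::vector<llama_seq_id> & seq_ids,
                     bool logits) {
    // append one token at the next free slot of the batch
    batch.token   [batch.n_tokens] = id;
    batch.pos     [batch.n_tokens] = pos;
    batch.n_seq_id[batch.n_tokens] = (int32_t) seq_ids.size();
    for (size_t i = 0; i < seq_ids.size(); ++i) {
        batch.seq_id[batch.n_tokens][i] = seq_ids[i];
    }
    batch.logits  [batch.n_tokens] = logits;

    batch.n_tokens++;
}
```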
Diffstat (limited to 'examples/batched/batched.cpp')
-rw-r--r-- | examples/batched/batched.cpp | 26 |
1 file changed, 8 insertions, 18 deletions
```diff
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index a88e022d..15521216 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -97,20 +97,15 @@ int main(int argc, char ** argv) {
     fflush(stderr);
 
-    // create a llama_batch with size 512
+    // create a llama_batch
     // we use this object to submit token data for decoding
-
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
 
     // evaluate the initial prompt
-    batch.n_tokens = tokens_list.size();
-
-    for (int32_t i = 0; i < batch.n_tokens; i++) {
-        batch.token[i]  = tokens_list[i];
-        batch.pos[i]    = i;
-        batch.seq_id[i] = 0;
-        batch.logits[i] = false;
+    for (size_t i = 0; i < tokens_list.size(); ++i) {
+        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
     }
+    GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
 
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
@@ -146,7 +141,7 @@ int main(int argc, char ** argv) {
 
     while (n_cur <= n_len) {
         // prepare the next batch
-        batch.n_tokens = 0;
+        llama_batch_clear(batch);
 
         // sample the next token for each parallel sequence / stream
         for (int32_t i = 0; i < n_parallel; ++i) {
@@ -198,15 +193,10 @@ int main(int argc, char ** argv) {
 
             streams[i] += llama_token_to_piece(ctx, new_token_id);
 
-            // push this new token for next evaluation
-            batch.token [batch.n_tokens] = new_token_id;
-            batch.pos   [batch.n_tokens] = n_cur;
-            batch.seq_id[batch.n_tokens] = i;
-            batch.logits[batch.n_tokens] = true;
-
             i_batch[i] = batch.n_tokens;
 
-            batch.n_tokens += 1;
+            // push this new token for next evaluation
+            llama_batch_add(batch, new_token_id, n_cur, { i }, true);
 
             n_decode += 1;
         }
```
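Taken together, the refactored flow in the example reduces to the pattern below. This is a minimal, self-contained sketch rather than code from the commit; it assumes a loaded `llama_context * ctx`, an already tokenized prompt, and that the new third argument of `llama_batch_init()` is the maximum number of sequence ids per token (1 here, since each token belongs to a single sequence):

```cpp
// Minimal sketch of the refactored batch workflow (not part of the commit itself).
// Assumptions: a loaded llama_context * ctx and a tokenized prompt.
#include <vector>

#include "common.h"
#include "llama.h"

static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt) {
    // n_tokens = prompt size, no embeddings, at most 1 sequence id per token
    llama_batch batch = llama_batch_init((int32_t) prompt.size(), 0, 1);

    for (size_t i = 0; i < prompt.size(); ++i) {
        // all prompt tokens go to sequence 0; only the last one needs logits
        llama_batch_add(batch, prompt[i], (llama_pos) i, { 0 }, i == prompt.size() - 1);
    }

    const bool ok = llama_decode(ctx, batch) == 0;

    llama_batch_free(batch);
    return ok;
}
```

In the generation loop of the example, the same helper is called with `{ i }`, so each continuation token is tagged with its own sequence id and the parallel streams stay separate during decoding.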