diff options
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index 3e4b8a8c..ce14d66b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector<llama_seq_id> & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos, + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; +} + std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); |