summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 3e4b8a8c..ce14d66b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
+void llama_batch_clear(struct llama_batch & batch) {
+ batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector<llama_seq_id> & seq_ids,
+ bool logits) {
+ batch.token [batch.n_tokens] = id;
+ batch.pos [batch.n_tokens] = pos,
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+ }
+ batch.logits [batch.n_tokens] = logits;
+
+ batch.n_tokens++;
+}
+
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);