summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-10-18 16:21:57 +0300
committerGitHub <noreply@github.com>2023-10-18 16:21:57 +0300
commit0e89203b517c95ec6675eda75d200a60d1e8921d (patch)
tree3aba40ef0362d061f240bd43c52e86a8f728f89d /common/common.cpp
parentc67fe68e417f766970fb1feaf2e66458aa24116a (diff)
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 3e4b8a8c..ce14d66b 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
return cparams;
}
+void llama_batch_clear(struct llama_batch & batch) {
+ batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector<llama_seq_id> & seq_ids,
+ bool logits) {
+ batch.token [batch.n_tokens] = id;
+ batch.pos [batch.n_tokens] = pos,
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+ }
+ batch.logits [batch.n_tokens] = logits;
+
+ batch.n_tokens++;
+}
+
std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params) {
auto mparams = llama_model_params_from_gpt_params(params);