From 0e89203b517c95ec6675eda75d200a60d1e8921d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 18 Oct 2023 16:21:57 +0300 Subject: speculative : add tree-based sampling example (#3624) * sampling : one sequence per sampling context ggml-ci * speculative : add tree-based sampling support ggml-ci * speculative : reuse the n_parallel CLI param * speculative : refactor sampling * examples : fix build after sampling refactoring ggml-ci * batched : fix n_seq_id * sampling : fix malloc ggml-ci * swift : fix build ggml-ci * swift : try to fix build ggml-ci * prompts : add assistant.txt * common : add llama_batch_add() and llama_batch_clear() helpers * speculative : minor refactor ggml-ci * minor : comments + rename ggml-ci * speculative : fix off-by-one for n_drafted * speculative : fix the n_drafted fix + p constants --- common/common.cpp | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'common/common.cpp') diff --git a/common/common.cpp b/common/common.cpp index 3e4b8a8c..ce14d66b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -820,6 +820,27 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return cparams; } +void llama_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; +} + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector & seq_ids, + bool logits) { + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos, + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; +} + std::tuple llama_init_from_gpt_params(gpt_params & params) { auto mparams = llama_model_params_from_gpt_params(params); -- cgit v1.2.3