diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-18 16:21:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-18 16:21:57 +0300 |
commit | 0e89203b517c95ec6675eda75d200a60d1e8921d (patch) | |
tree | 3aba40ef0362d061f240bd43c52e86a8f728f89d /common/common.h | |
parent | c67fe68e417f766970fb1feaf2e66458aa24116a (diff) |
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context
ggml-ci
* speculative : add tree-based sampling support
ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring
ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc
ggml-ci
* swift : fix build
ggml-ci
* swift : try to fix build
ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers
* speculative : minor refactor
ggml-ci
* minor : comments + rename
ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'common/common.h')
-rw-r--r-- | common/common.h | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/common/common.h b/common/common.h index 08c60323..65d3d20c 100644 --- a/common/common.h +++ b/common/common.h @@ -70,6 +70,7 @@ struct gpt_params { std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted std::string logdir = ""; // directory in which to save YAML log files + // TODO: avoid tuple, use struct std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale std::string lora_base = ""; // base model path for the lora adapter @@ -124,10 +125,23 @@ void process_escapes(std::string& input); // Model utils // +// TODO: avoid tuplue, use struct std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_params(gpt_params & params); -struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params); + +struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); +// Batch utils + +void llama_batch_clear(struct llama_batch & batch); + +void llama_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + const std::vector<llama_seq_id> & seq_ids, + bool logits); + // // Vocab utils // |