diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-18 16:21:57 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-18 16:21:57 +0300 |
commit | 0e89203b517c95ec6675eda75d200a60d1e8921d (patch) | |
tree | 3aba40ef0362d061f240bd43c52e86a8f728f89d /common/sampling.h | |
parent | c67fe68e417f766970fb1feaf2e66458aa24116a (diff) |
speculative : add tree-based sampling example (#3624)
* sampling : one sequence per sampling context
ggml-ci
* speculative : add tree-based sampling support
ggml-ci
* speculative : reuse the n_parallel CLI param
* speculative : refactor sampling
* examples : fix build after sampling refactoring
ggml-ci
* batched : fix n_seq_id
* sampling : fix malloc
ggml-ci
* swift : fix build
ggml-ci
* swift : try to fix build
ggml-ci
* prompts : add assistant.txt
* common : add llama_batch_add() and llama_batch_clear() helpers
* speculative : minor refactor
ggml-ci
* minor : comments + rename
ggml-ci
* speculative : fix off-by-one for n_drafted
* speculative : fix the n_drafted fix + p constants
Diffstat (limited to 'common/sampling.h')
-rw-r--r-- | common/sampling.h | 87 |
1 files changed, 39 insertions, 48 deletions
diff --git a/common/sampling.h b/common/sampling.h index 0aab5d03..50afcbc1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,6 +2,8 @@ #include "llama.h" +#include "grammar-parser.h" + #include <string> #include <vector> #include <unordered_map> @@ -34,75 +36,64 @@ typedef struct llama_sampling_params { } llama_sampling_params; -// per-sequence sampler context -typedef struct llama_sampler_sequence_context { - float mirostat_mu; // mirostat sampler state - llama_grammar * grammar; -} llama_sampler_sequence_context; - // general sampler context -typedef struct llama_sampling_context { - ~llama_sampling_context(); - - // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_context instances +// TODO: move to llama.h +struct llama_sampling_context { + // parameters that will be used for sampling llama_sampling_params params; - // map of sequence ids to sampler contexts - std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts; + // mirostat sampler state + float mirostat_mu; - // when non-NULL, new instances of llama_sampler_sequence_context - // will get a copy of the grammar here - // note: only the pointer is stored here, it is not a copy of - // the grammar and shouldn't be freed llama_grammar * grammar; -} llama_sampling_context; + + // internal + grammar_parser::parse_state parsed_grammar; + + // TODO: replace with ring-buffer + std::vector<llama_token> prev; + std::vector<llama_token_data> cur; +}; #include "common.h" // Create a new sampling context instance. -llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar = NULL); - -// Fetches the sampler context for the specified sequence id (defaults to 0). -// If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the ctx_sampling argument. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - -// Reset the sampler context for the supplied sequence id (defaults to 0). -// This is necessary to reuse a sequence id or free memory used by sequences -// that are no longer required. -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); +struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); + +void llama_sampling_free(struct llama_sampling_context * ctx); + +// Reset the sampler context +// - clear prev tokens +// - reset grammar +void llama_sampling_reset(llama_sampling_context * ctx); + +// Copy the sampler context +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_context_reset when a sequence ends +// llama_sampling_reset when a sequence ends // // required: -// - ctx: context to use for sampling +// - ctx_main: context to use for sampling // - ctx_sampling: sampling-specific context // // optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with +// - ctx_cfg: context to use for classifier-free guidance +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector<llama_token> & last_tokens, - std::vector<llama_token_data> & candidates, - const int idx = 0, - llama_seq_id seq = 0); + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id); |