diff options
Diffstat (limited to 'common/sampling.h')
-rw-r--r-- | common/sampling.h | 87 |
1 files changed, 39 insertions, 48 deletions
diff --git a/common/sampling.h b/common/sampling.h index 0aab5d03..50afcbc1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,6 +2,8 @@ #include "llama.h" +#include "grammar-parser.h" + #include <string> #include <vector> #include <unordered_map> @@ -34,75 +36,64 @@ typedef struct llama_sampling_params { } llama_sampling_params; -// per-sequence sampler context -typedef struct llama_sampler_sequence_context { - float mirostat_mu; // mirostat sampler state - llama_grammar * grammar; -} llama_sampler_sequence_context; - // general sampler context -typedef struct llama_sampling_context { - ~llama_sampling_context(); - - // parameters that will be used for sampling and when creating - // new llama_sampler_sequence_context instances +// TODO: move to llama.h +struct llama_sampling_context { + // parameters that will be used for sampling llama_sampling_params params; - // map of sequence ids to sampler contexts - std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts; + // mirostat sampler state + float mirostat_mu; - // when non-NULL, new instances of llama_sampler_sequence_context - // will get a copy of the grammar here - // note: only the pointer is stored here, it is not a copy of - // the grammar and shouldn't be freed llama_grammar * grammar; -} llama_sampling_context; + + // internal + grammar_parser::parse_state parsed_grammar; + + // TODO: replace with ring-buffer + std::vector<llama_token> prev; + std::vector<llama_token_data> cur; +}; #include "common.h" // Create a new sampling context instance. -llama_sampling_context llama_sampling_context_init( - const struct gpt_params & params, - llama_grammar * grammar = NULL); - -// Fetches the sampler context for the specified sequence id (defaults to 0). -// If the context for that sequence id doesn't already exist, it will be created with -// default values based on the parameters in the ctx_sampling argument. -llama_sampler_sequence_context & llama_sampling_get_sequence_context( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); - -// Reset the sampler context for the supplied sequence id (defaults to 0). -// This is necessary to reuse a sequence id or free memory used by sequences -// that are no longer required. -bool llama_sampling_context_reset( - llama_sampling_context & ctx_sampling, - const llama_seq_id seq = 0); +struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params); + +void llama_sampling_free(struct llama_sampling_context * ctx); + +// Reset the sampler context +// - clear prev tokens +// - reset grammar +void llama_sampling_reset(llama_sampling_context * ctx); + +// Copy the sampler context +void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); // this is a common sampling function used across the examples for convenience // it can serve as a starting point for implementing your own sampling function // Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_context_reset when a sequence ends +// llama_sampling_reset when a sequence ends // // required: -// - ctx: context to use for sampling +// - ctx_main: context to use for sampling // - ctx_sampling: sampling-specific context // // optional: -// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL -// - last_tokens: needed for repetition penalty, ignore if empty -// - idx: sample from llama_get_logits_ith(ctx, idx) -// - seq: sequence id to associate sampler state with +// - ctx_cfg: context to use for classifier-free guidance +// - idx: sample from llama_get_logits_ith(ctx, idx) // // returns: // - token: sampled token // - candidates: vector of candidate tokens // llama_token llama_sampling_sample( - struct llama_context * ctx, - struct llama_context * ctx_guidance, - struct llama_sampling_context & ctx_sampling, - const std::vector<llama_token> & last_tokens, - std::vector<llama_token_data> & candidates, - const int idx = 0, - llama_seq_id seq = 0); + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + +void llama_sampling_accept( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + llama_token id); |