summary refs log tree commit diff
path: root/common/sampling.h
diff options
context:
space:
mode:
Diffstat (limited to 'common/sampling.h')
-rw-r--r--  common/sampling.h  87
1 file changed, 39 insertions, 48 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 0aab5d03..50afcbc1 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -2,6 +2,8 @@
#include "llama.h"
+#include "grammar-parser.h"
+
#include <string>
#include <vector>
#include <unordered_map>
@@ -34,75 +36,64 @@ typedef struct llama_sampling_params {
} llama_sampling_params;
-// per-sequence sampler context
-typedef struct llama_sampler_sequence_context {
- float mirostat_mu; // mirostat sampler state
- llama_grammar * grammar;
-} llama_sampler_sequence_context;
-
// general sampler context
-typedef struct llama_sampling_context {
- ~llama_sampling_context();
-
- // parameters that will be used for sampling and when creating
- // new llama_sampler_sequence_context instances
+// TODO: move to llama.h
+struct llama_sampling_context {
+ // parameters that will be used for sampling
llama_sampling_params params;
- // map of sequence ids to sampler contexts
- std::unordered_map<llama_seq_id, llama_sampler_sequence_context> sequence_contexts;
+ // mirostat sampler state
+ float mirostat_mu;
- // when non-NULL, new instances of llama_sampler_sequence_context
- // will get a copy of the grammar here
- // note: only the pointer is stored here, it is not a copy of
- // the grammar and shouldn't be freed
llama_grammar * grammar;
-} llama_sampling_context;
+
+ // internal
+ grammar_parser::parse_state parsed_grammar;
+
+ // TODO: replace with ring-buffer
+ std::vector<llama_token> prev;
+ std::vector<llama_token_data> cur;
+};
#include "common.h"
// Create a new sampling context instance.
-llama_sampling_context llama_sampling_context_init(
- const struct gpt_params & params,
- llama_grammar * grammar = NULL);
-
-// Fetches the sampler context for the specified sequence id (defaults to 0).
-// If the context for that sequence id doesn't already exist, it will be created with
-// default values based on the parameters in the ctx_sampling argument.
-llama_sampler_sequence_context & llama_sampling_get_sequence_context(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq = 0);
-
-// Reset the sampler context for the supplied sequence id (defaults to 0).
-// This is necessary to reuse a sequence id or free memory used by sequences
-// that are no longer required.
-bool llama_sampling_context_reset(
- llama_sampling_context & ctx_sampling,
- const llama_seq_id seq = 0);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+
+void llama_sampling_free(struct llama_sampling_context * ctx);
+
+// Reset the sampler context
+// - clear prev tokens
+// - reset grammar
+void llama_sampling_reset(llama_sampling_context * ctx);
+
+// Copy the sampler context
+void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
-// llama_sampling_context_reset when a sequence ends
+// llama_sampling_reset when a sequence ends
//
// required:
-// - ctx: context to use for sampling
+// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
-// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-// - last_tokens: needed for repetition penalty, ignore if empty
-// - idx: sample from llama_get_logits_ith(ctx, idx)
-// - seq: sequence id to associate sampler state with
+// - ctx_cfg: context to use for classifier-free guidance
+// - idx: sample from llama_get_logits_ith(ctx, idx)
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
- struct llama_context * ctx,
- struct llama_context * ctx_guidance,
- struct llama_sampling_context & ctx_sampling,
- const std::vector<llama_token> & last_tokens,
- std::vector<llama_token_data> & candidates,
- const int idx = 0,
- llama_seq_id seq = 0);
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ int idx = 0);
+
+void llama_sampling_accept(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ llama_token id);