path: root/common/common.h
Diffstat (limited to 'common/common.h')
-rw-r--r--  common/common.h | 36 ++++++++++++++++++++++++++++++++++++
1 file changed, 36 insertions(+), 0 deletions(-)
diff --git a/common/common.h b/common/common.h
index 5a379688..105fb09e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -32,6 +32,7 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_draft = 16; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
@@ -63,6 +64,7 @@ struct gpt_params {
float cfg_scale = 1.f; // How strong is guidance

std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+ std::string model_draft = ""; // draft model for speculative decoding
std::string model_alias = "unknown"; // model alias
std::string prompt = "";
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
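
For orientation: the two new fields work together for speculative decoding. A small model loaded from model_draft cheaply proposes up to n_draft tokens, and the main model then verifies them, keeping the longest prefix on which both agree. Below is a minimal greedy sketch of that idea; eval_logits and sample_greedy are hypothetical placeholders (not llama.cpp API), and a real implementation verifies the whole draft in a single batched evaluation rather than token by token.

// Conceptual sketch only; eval_logits/sample_greedy are hypothetical
// placeholders, not llama.cpp functions.
#include <vector>

using token_t = int;

std::vector<float> eval_logits(void * ctx, token_t tok, int n_past); // placeholder
token_t            sample_greedy(const std::vector<float> & logits); // placeholder

// one round of draft-then-verify (greedy variant)
void speculative_step(void * ctx_main, void * ctx_draft, int n_draft,
                      std::vector<token_t> & out, int & n_past) {
    // 1) the cheap draft model proposes up to n_draft continuation tokens
    std::vector<token_t> draft;
    token_t last = out.back();
    for (int i = 0; i < n_draft; ++i) {
        last = sample_greedy(eval_logits(ctx_draft, last, n_past + i));
        draft.push_back(last);
    }

    // 2) the main model re-derives each position: accept proposals while they
    //    match its own greedy choice, and stop at the first disagreement
    //    (keeping the main model's token, so output equals plain greedy decoding)
    for (token_t proposed : draft) {
        const token_t verified = sample_greedy(eval_logits(ctx_main, out.back(), n_past++));
        out.push_back(verified);
        if (verified != proposed) {
            break; // discard the rest of the draft
        }
    }
}
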
@@ -156,6 +158,40 @@ std::string llama_detokenize_bpe(
llama_context * ctx,
const std::vector<llama_token> & tokens);

+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+// - ctx: context to use for sampling
+// - params: sampling parameters
+//
+// optional:
+// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+// - grammar: grammar to use for sampling, ignore if NULL
+// - last_tokens: needed for repetition penalty, ignore if empty
+// - idx: sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+// - token: sampled token
+// - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+ struct llama_context * ctx,
+ struct llama_context * ctx_guidance,
+ struct llama_grammar * grammar,
+ const struct gpt_params & params,
+ const std::vector<llama_token> & last_tokens,
+ std::vector<llama_token_data> & candidates,
+ int idx = 0);
+
+//
+// YAML utils
+//
+
bool create_directory_with_parents(const std::string & path);
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
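
A minimal usage sketch for the new sampler entry point, assuming a llama_context * ctx and a populated gpt_params params already exist, and that n_predict stands in for the number of tokens the caller wants to generate:

// Sketch: driving llama_sample_token from a plain generation loop.
// ctx and params are assumed to be set up elsewhere; passing NULL for
// ctx_guidance and grammar simply disables those features.
std::vector<llama_token>      last_tokens; // history for the repetition penalty
std::vector<llama_token_data> candidates;  // scratch buffer filled by the sampler

for (int i = 0; i < n_predict; ++i) {
    const llama_token id = llama_sample_token(
            ctx,
            /* ctx_guidance */ NULL,
            /* grammar      */ NULL,
            params,
            last_tokens,
            candidates); // default idx = 0: logits of the last evaluated token

    last_tokens.push_back(id);

    // ... evaluate id with the model here to produce logits for the next step ...
}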