| author    | Georgi Gerganov <ggerganov@gmail.com>                      | 2023-09-03 15:12:08 +0300 |
|-----------|------------------------------------------------------------|---------------------------|
| committer | GitHub <noreply@github.com>                                | 2023-09-03 15:12:08 +0300 |
| commit    | 47068e517004d90f13c16352bb3b4cafd53a00cd (patch)           |                           |
| tree      | 259f1fb1184775dc250452d319c8006c0704ea22 /common/common.h  |                           |
| parent    | 8f429fa5111901f9646cf998643ac5310846d487 (diff)            |                           |
speculative : PoC for speeding-up inference via speculative sampling (#2926)
* speculative : initial example
* speculative : print encoding speed
* speculative : add --draft CLI arg
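For context, the scheme this PoC implements is draft-then-verify: a small draft model proposes up to `n_draft` tokens one at a time, the target model then scores the whole proposal in a single forward pass, and drafted tokens are accepted for as long as the target model samples the same token. The following is a minimal sketch of one such step under stated assumptions, not the actual code from `examples/speculative`: it assumes the `llama_eval`-based API of this period together with the `llama_sample_token` helper declared in the diff below, and `speculative_step`, `n_past_tgt`, and `n_past_dft` are hypothetical names introduced here for illustration.

```cpp
#include "common.h"
#include "llama.h"

#include <vector>

// One draft-then-verify step (hypothetical helper, for illustration only).
// Preconditions: both contexts have already evaluated the same accepted
// prefix (n_past_tgt == n_past_dft), `cur` is the last accepted token that
// has not been evaluated yet, and the target context is assumed to keep
// per-position logits (logits_all), since verification samples at idx > 0.
static void speculative_step(
        llama_context * ctx_tgt,              // target (large) model
        llama_context * ctx_dft,              // draft  (small) model
        const gpt_params & params,
        std::vector<llama_token> & tokens,    // accepted tokens so far
        llama_token & cur,
        int & n_past_tgt,
        int & n_past_dft) {
    std::vector<llama_token_data> candidates; // scratch buffer for the sampler

    // 1) draft params.n_draft tokens cheaply with the small model
    std::vector<llama_token> drafted;
    llama_token prev = cur;
    for (int i = 0; i < params.n_draft; ++i) {
        llama_eval(ctx_dft, &prev, 1, n_past_dft + i, params.n_threads);
        prev = llama_sample_token(ctx_dft, NULL, NULL, params, tokens, candidates);
        drafted.push_back(prev);
    }

    // 2) score the entire draft with the target model in a single eval;
    //    the batch is [cur, drafted[0], ..., drafted[n-2]], so the logits
    //    at index i predict the token at drafted[i]'s position
    std::vector<llama_token> batch = { cur };
    batch.insert(batch.end(), drafted.begin(), drafted.end() - 1);
    llama_eval(ctx_tgt, batch.data(), (int) batch.size(), n_past_tgt, params.n_threads);

    // 3) accept drafted tokens while the target model samples the same
    //    token; the first disagreement is replaced by the target's own
    //    choice and the rest of the draft is discarded
    int n_valid = 0; // batch positions whose KV-cache entries remain valid
    for (int i = 0; i < (int) drafted.size(); ++i) {
        cur = llama_sample_token(ctx_tgt, NULL, NULL, params, tokens, candidates, i);
        tokens.push_back(cur);
        n_valid = i + 1;
        if (cur != drafted[i]) {
            break;
        }
    }

    // with this llama_eval-based API no explicit cache rollback is needed:
    // the next eval at these n_past values overwrites the rejected positions
    n_past_tgt += n_valid;
    n_past_dft  = n_past_tgt;
}
```

The payoff is that every run of accepted draft tokens costs one target-model forward pass instead of one pass per token; when the draft model disagrees often, the scheme degrades toward normal decoding plus the (cheap) drafting overhead.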
Diffstat (limited to 'common/common.h')
| -rw-r--r-- | common/common.h | 36 |
|------------|-----------------|-----|

1 file changed, 36 insertions, 0 deletions
```diff
diff --git a/common/common.h b/common/common.h
index 5a379688..105fb09e 100644
--- a/common/common.h
+++ b/common/common.h
@@ -32,6 +32,7 @@ struct gpt_params {
     int32_t n_ctx        = 512;  // context size
     int32_t n_batch      = 512;  // batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep       = 0;    // number of tokens to keep from initial prompt
+    int32_t n_draft      = 16;   // number of tokens to draft during speculative decoding
     int32_t n_chunks     = -1;   // max number of chunks to process (-1 = unlimited)
     int32_t n_gpu_layers = 0;    // number of layers to store in VRAM
     int32_t main_gpu     = 0;    // the GPU that is used for scratch and small tensors
@@ -63,6 +64,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model       = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = "";                              // draft model for speculative decoding
     std::string model_alias = "unknown";                       // model alias
     std::string prompt            = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
@@ -156,6 +158,40 @@ std::string llama_detokenize_bpe(
         llama_context * ctx,
         const std::vector<llama_token> & tokens);
 
+//
+// Sampling utils
+//
+
+// this is a common sampling function used across the examples for convenience
+// it can serve as a starting point for implementing your own sampling function
+//
+// required:
+//  - ctx:    context to use for sampling
+//  - params: sampling parameters
+//
+// optional:
+//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
+//  - grammar:      grammar to use for sampling, ignore if NULL
+//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - idx:          sample from llama_get_logits(ctx) + idx * n_vocab
+//
+// returns:
+//  - token:      sampled token
+//  - candidates: vector of candidate tokens
+//
+llama_token llama_sample_token(
+        struct llama_context * ctx,
+        struct llama_context * ctx_guidance,
+        struct llama_grammar * grammar,
+        const struct gpt_params & params,
+        const std::vector<llama_token> & last_tokens,
+        std::vector<llama_token_data> & candidates,
+        int idx = 0);
+
+//
+// YAML utils
+//
+
 bool create_directory_with_parents(const std::string & path);
 void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
 void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
```
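For a plainer use of the new helper outside the speculative example, here is a minimal generation loop; again a sketch under assumptions, not repository code: `generate_n` is a hypothetical name, the guidance context and grammar are disabled by passing NULL, the default `idx = 0` reads the logits of the last evaluated position, and `llama_token_to_piece` is assumed to be the detokenization helper available in `common` at this point.

```cpp
#include "common.h"
#include "llama.h"

#include <string>
#include <vector>

// Hypothetical helper: append n_gen sampled tokens to a tokenized prompt.
static std::string generate_n(
        llama_context * ctx,
        const gpt_params & params,
        std::vector<llama_token> & tokens, // prompt in, prompt + completion out
        int n_gen) {
    std::string out;
    std::vector<llama_token_data> candidates; // reused across sampling calls

    // evaluate the prompt once
    llama_eval(ctx, tokens.data(), (int) tokens.size(), 0, params.n_threads);

    for (int i = 0; i < n_gen; ++i) {
        // NULL guidance context and grammar are ignored by the sampler;
        // `tokens` feeds the repetition penalty
        const llama_token id = llama_sample_token(ctx, NULL, NULL, params, tokens, candidates);

        out += llama_token_to_piece(ctx, id);
        tokens.push_back(id);

        // feed the sampled token back in for the next step
        llama_eval(ctx, &id, 1, (int) tokens.size() - 1, params.n_threads);
    }

    return out;
}
```

Note that `candidates` is the out-parameter the sampler refills on every call, so reusing one vector across iterations avoids reallocating a vocabulary-sized array each step.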