From 6d341ab6c53cd51f2921d986d0090cc8b049b39a Mon Sep 17 00:00:00 2001 From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Date: Tue, 5 Mar 2024 03:24:00 +0900 Subject: speculative : implement stochastic speculative sampling (#5625) * (WIP) Implement stochastic speculative decoding * sample from residual distribution on draft accept failure * fix #5657: force greedy sampling with probs when temp is 0 * remove p_accept parameter * fix style * remove unused variables * add srand() in speculative.cpp * replace use of rand() with mt19937 sampling * fixes based on review (@JohannesGaessler) * fix r random generation * randomly select next sequence to verify + fix bug in memory freeing * fix bug in active_seqs sync * fix uniform int distribution initialization * remove warnings from comparison between int and size_t * check grammar in `llama_sample_probability_distribution_impl` * remove malloc code by utilizing vectors * add PR link to README --- common/sampling.h | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'common/sampling.h') diff --git a/common/sampling.h b/common/sampling.h index 95d87539..48b2459d 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -131,6 +131,13 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, int idx = 0); +// returns the probability that token of given id will be sampled +llama_token_data_array llama_sampling_probability_distribution( + struct llama_sampling_context * ctx_sampling, + struct llama_context * ctx_main, + struct llama_context * ctx_cfg, + int idx = 0); + void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, -- cgit v1.2.3