From 586e7bc561be88e929a9afca7e67d8ead00c53bd Mon Sep 17 00:00:00 2001 From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Date: Sun, 24 Mar 2024 17:54:07 +0900 Subject: sampling : deduplicated code for probability distribution access (#6240) * sampling: remove duplicated code for probability distribution access * free original_logits * fix original_logits allocation * fixes based on review @cebtenzzre * change function name to `llama_sampling_prepare` --- common/sampling.h | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'common/sampling.h') diff --git a/common/sampling.h b/common/sampling.h index 79a998be..56ed991b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -131,12 +131,14 @@ llama_token llama_sampling_sample( struct llama_context * ctx_cfg, int idx = 0); -// returns the probability that token of given id will be sampled -llama_token_data_array llama_sampling_probability_distribution( +// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. +llama_token_data_array llama_sampling_prepare( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, struct llama_context * ctx_cfg, - int idx = 0); + int idx = 0, + bool apply_grammar = true, + std::vector<float> * original_logits = nullptr); void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, -- cgit v1.2.3