summaryrefslogtreecommitdiff
path: root/common/sampling.h
diff options
context:
space:
mode:
authorMinsoo Cheong <54794500+mscheong01@users.noreply.github.com>2024-03-24 17:54:07 +0900
committerGitHub <noreply@github.com>2024-03-24 10:54:07 +0200
commit586e7bc561be88e929a9afca7e67d8ead00c53bd (patch)
tree7cf1ce22a8b8ea7d711662947fc698e1718905d3 /common/sampling.h
parentddf65685105a39a57b1e7f80c3aa502a6313af24 (diff)
sampling : deduplicated code for probability distribution access (#6240)
* sampling: remove duplicated code for probability distribution access * free original_logits * fix original_logits allocation * fixes based on review @cebtenzzre * change function name to `llama_sampling_prepare`
Diffstat (limited to 'common/sampling.h')
-rw-r--r--common/sampling.h8
1 files changed, 5 insertions, 3 deletions
diff --git a/common/sampling.h b/common/sampling.h
index 79a998be..56ed991b 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -131,12 +131,14 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);
-// returns the probability that token of given id will be sampled
-llama_token_data_array llama_sampling_probability_distribution(
+// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
+llama_token_data_array llama_sampling_prepare(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- int idx = 0);
+ int idx = 0,
+ bool apply_grammar = true,
+ std::vector<float> * original_logits = nullptr);
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,