From 586e7bc561be88e929a9afca7e67d8ead00c53bd Mon Sep 17 00:00:00 2001 From: Minsoo Cheong <54794500+mscheong01@users.noreply.github.com> Date: Sun, 24 Mar 2024 17:54:07 +0900 Subject: sampling : deduplicated code for probability distribution access (#6240) * sampling: remove duplicated code for probability distribution access * free original_logits * fix original_logits allocation * fixes based on review @cebtenzzre * change function name to `llama_sampling_prepare` --- examples/speculative/speculative.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'examples/speculative/speculative.cpp') diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e991b884..8b31b678 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -219,7 +219,8 @@ int main(int argc, char ** argv) { if (params.sparams.temp > 0) { // stochastic verification - llama_token_data_array dist_tgt = llama_sampling_probability_distribution(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); + llama_sample_softmax(ctx_tgt, &dist_tgt); float p_tgt = 0, p_dft = 0; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); -- cgit v1.2.3