summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp7
1 files changed, 5 insertions, 2 deletions
diff --git a/llama.cpp b/llama.cpp
index e4ca34bd..3a4a03d8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -13667,7 +13667,7 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da
return result;
}
-llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) {
GGML_ASSERT(ctx);
const int64_t t_start_sample_us = ggml_time_us();
@@ -13680,7 +13680,6 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
}
std::discrete_distribution<> dist(probs.begin(), probs.end());
- auto & rng = ctx->rng;
int idx = dist(rng);
llama_token result = candidates->data[idx].id;
@@ -13690,6 +13689,10 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra
return result;
}
+llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) {
+ return llama_sample_token_with_rng(ctx, candidates, ctx->rng);
+}
+
void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) {
const int64_t t_start_sample_us = ggml_time_us();