summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/sampling.cpp5
-rw-r--r--common/sampling.h1
2 files changed, 6 insertions, 0 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp
index cc83600d..3715a798 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
result->prev.resize(params.n_prev);
+ result->n_considered = 0;
+
llama_sampling_set_rng_seed(result, params.seed);
return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
+ ctx->n_considered = 0;
}
void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
}
}
+ ctx_sampling->n_considered = cur_p.size;
+
return id;
}
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e3..5b73ecdc 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -81,6 +81,7 @@ struct llama_sampling_context {
// TODO: replace with ring-buffer
std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
+ size_t n_considered;
std::mt19937 rng;
};