From 02f0c6fe7f9b7be24c7d339aed016e54a92388ea Mon Sep 17 00:00:00 2001
From: beiller
Date: Sun, 12 Mar 2023 16:23:15 -0400
Subject: Add back top_k (#56)

* Add back top_k

* Update utils.cpp

* Update utils.h

---------

Co-authored-by: Bill Hamilton
Co-authored-by: Georgi Gerganov
---
 utils.h | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/utils.h b/utils.h
index e331904b..5b3d7364 100644
--- a/utils.h
+++ b/utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64; // last n tokens to penalize
 
     // sampling parameters
-    int32_t top_k = 40; // unused
+    int32_t top_k = 40;
     float   top_p = 0.95f;
     float   temp  = 0.80f;
     float   repeat_penalty  = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 // - consider only the top K tokens
 // - from them, consider only the top tokens with cumulative probability > P
 //
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int top_k,
         double top_p,
         double temp,
         std::mt19937 & rng);
 
+// filter to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+
 //
 // Quantization
 //
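
This hunk only changes declarations in utils.h; the corresponding definitions live in utils.cpp, which is not shown here. As a rough illustration of what a combined top-k / top-p sampler of this shape does, the standalone C++ sketch below (not the actual llama.cpp code; the names keep_top_k and sample_top_p_top_k_sketch are made up for the example) scales logits by temperature, keeps the top_k highest-scoring tokens, applies a softmax, truncates to the smallest prefix whose cumulative probability reaches top_p, and samples from what remains. The last_n_tokens / repeat_penalty handling taken by the declared llama_sample_top_p_top_k (penalizing recently generated tokens) is omitted for brevity.

// Illustrative sketch of top-k / top-p sampling over raw logits.
// This is NOT the utils.cpp implementation from the commit above; it only
// demonstrates the technique described by the declarations in the diff.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

// Keep only the top_k highest-scoring (logit, token id) pairs,
// similar in spirit to the sample_top_k helper declared above.
static void keep_top_k(std::vector<std::pair<double, int>> & logits_id, int top_k) {
    const int k = std::min(top_k, (int) logits_id.size());
    std::partial_sort(logits_id.begin(), logits_id.begin() + k, logits_id.end(),
                      [](const auto & a, const auto & b) { return a.first > b.first; });
    logits_id.resize(k);
}

// Sample a token id from logits using temperature, top-k and top-p filtering.
// Assumes temp > 0 and a non-empty logits vector.
static int sample_top_p_top_k_sketch(const std::vector<float> & logits,
                                     int top_k, double top_p, double temp,
                                     std::mt19937 & rng) {
    std::vector<std::pair<double, int>> logits_id;
    logits_id.reserve(logits.size());
    for (int i = 0; i < (int) logits.size(); ++i) {
        logits_id.emplace_back(logits[i] / temp, i); // temperature scaling
    }

    keep_top_k(logits_id, top_k);

    // softmax over the surviving logits (subtract the max for numerical stability)
    const double max_l = logits_id.front().first;
    double sum = 0.0;
    std::vector<double> probs;
    probs.reserve(logits_id.size());
    for (const auto & p : logits_id) {
        const double e = std::exp(p.first - max_l);
        probs.push_back(e);
        sum += e;
    }
    for (auto & p : probs) {
        p /= sum;
    }

    // top-p: keep the smallest prefix whose cumulative probability reaches top_p
    double cum = 0.0;
    size_t keep = probs.size();
    for (size_t i = 0; i < probs.size(); ++i) {
        cum += probs[i];
        if (cum >= top_p) {
            keep = i + 1;
            break;
        }
    }
    probs.resize(keep);

    // discrete_distribution renormalizes the remaining weights internally
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return logits_id[dist(rng)].second;
}

int main() {
    std::mt19937 rng(1234);
    std::vector<float> logits = {1.0f, 3.0f, 0.5f, 2.5f, -1.0f};
    const int id = sample_top_p_top_k_sketch(logits, /*top_k=*/3, /*top_p=*/0.95, /*temp=*/0.8, rng);
    printf("sampled token id: %d\n", id);
    return 0;
}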