summaryrefslogtreecommitdiff
path: root/utils.h
diff options
context:
space:
mode:
authorbeiller <beiller@gmail.com>2023-03-12 16:23:15 -0400
committerGitHub <noreply@github.com>2023-03-12 22:23:15 +0200
commit02f0c6fe7f9b7be24c7d339aed016e54a92388ea (patch)
tree1da03a06b631ee711abeafa9ff410d87bd579f5d /utils.h
parenteb062bb012c4e131818dd757a6d3a757fdee3961 (diff)
Add back top_k (#56)
* Add back top_k * Update utils.cpp * Update utils.h --------- Co-authored-by: Bill Hamilton <bill.hamilton@shopify.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'utils.h')
-rw-r--r--utils.h19
1 files changed, 6 insertions, 13 deletions
diff --git a/utils.h b/utils.h
index e331904b..5b3d7364 100644
--- a/utils.h
+++ b/utils.h
@@ -19,7 +19,7 @@ struct gpt_params {
int32_t repeat_last_n = 64; // last n tokens to penalize
// sampling parameters
- int32_t top_k = 40; // unused
+ int32_t top_k = 40;
float top_p = 0.95f;
float temp = 0.80f;
float repeat_penalty = 1.30f;
@@ -77,26 +77,19 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// - consider only the top K tokens
// - from them, consider only the top tokens with cumulative probability > P
//
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
- const gpt_vocab & vocab,
- const float * logits,
- int top_k,
- double top_p,
- double temp,
- std::mt19937 & rng);
-
-gpt_vocab::id llama_sample_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
const gpt_vocab & vocab,
const float * logits,
std::vector<gpt_vocab::id> & last_n_tokens,
double repeat_penalty,
+ int top_k,
double top_p,
double temp,
std::mt19937 & rng);
+// filer to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);
+
//
// Quantization
//