summaryrefslogtreecommitdiff
path: root/main.cpp
diff options
context:
space:
mode:
authorbeiller <beiller@gmail.com>2023-03-12 16:23:15 -0400
committerGitHub <noreply@github.com>2023-03-12 22:23:15 +0200
commit02f0c6fe7f9b7be24c7d339aed016e54a92388ea (patch)
tree1da03a06b631ee711abeafa9ff410d87bd579f5d /main.cpp
parenteb062bb012c4e131818dd757a6d3a757fdee3961 (diff)
Add back top_k (#56)
* Add back top_k * Update utils.cpp * Update utils.h --------- Co-authored-by: Bill Hamilton <bill.hamilton@shopify.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'main.cpp')
-rw-r--r--main.cpp3
1 files changed, 2 insertions, 1 deletions
diff --git a/main.cpp b/main.cpp
index a11d755a..01556143 100644
--- a/main.cpp
+++ b/main.cpp
@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
if (i >= embd_inp.size()) {
// sample next token
+ const float top_k = params.top_k;
const float top_p = params.top_p;
const float temp = params.temp;
const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_sample_us = ggml_time_us();
- id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+ id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);
last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id);