summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- common/sampling.cpp | 2
-rw-r--r-- llama.cpp | 4
-rw-r--r-- tests/test-sampling.cpp | 2
3 files changed, 7 insertions, 1 deletion
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e8675a8c..844ad7c5 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -132,7 +132,7 @@ static void sampler_queue(
const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent;
- const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
+ const int32_t top_k = params.top_k;
const float top_p = params.top_p;
const float min_p = params.min_p;
const float tfs_z = params.tfs_z;
diff --git a/llama.cpp b/llama.cpp
index c45ae1d5..f8f5796a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8585,6 +8585,10 @@ void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * can
// }
const int64_t t_start_sample_us = ggml_time_us();
+
+ if (k <= 0) {
+ k = candidates->size;
+ }
k = std::max(k, (int) min_keep);
k = std::min(k, (int) candidates->size);
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index c3b3d662..6374958f 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -235,6 +235,8 @@ int main(void) {
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);
+ test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);