summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
author: Alex Renda <alexrenda@users.noreply.github.com> 2023-06-24 03:15:01 -0700
committer: GitHub <noreply@github.com> 2023-06-24 13:15:01 +0300
commit: b061ba9e2a7a2c335a200df8c11aed5e31e4ccbb (patch)
tree: 0bf01d16d556af5ab9a3990d8859a5ff1f3ea4c0 /llama.cpp
parent: 527b6fba1d237befb324fd846bda7418c0fa394d (diff)
llama : fix top-p sampling to match the canonical definition (#1953)
* Fix top-p sampling to match the standard definition (smallest set that has probability mass at least p, not largest set with probability mass less than p)
* top-p: correct gt to gte
* add test for correct top-p behavior
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp7
1 files changed, 4 insertions, 3 deletions
diff --git a/llama.cpp b/llama.cpp
index a528eef4..ac22a48f 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can
for (size_t i = 0; i < candidates->size; ++i) {
cum_sum += candidates->data[i].p;
- // Check if the running sum is greater than p or if we have kept at least min_keep tokens
- if (cum_sum > p && i >= min_keep) {
- last_idx = i;
+ // Check if the running sum is at least p or if we have kept at least min_keep tokens
+ // we set the last index to i+1 to indicate that the current iterate should be included in the set
+ if (cum_sum >= p && i + 1 >= min_keep) {
+ last_idx = i + 1;
break;
}
}