summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
authorJohannes Gäßler <johannesg@5d6.de>2024-05-07 23:07:58 +0200
committerGitHub <noreply@github.com>2024-05-07 23:07:58 +0200
commitaf0a5b616359809ce886ea433acedebb39b12969 (patch)
treec9da4c284790e7ddd053c1c5f5e3b636581ee07a /examples/server/server.cpp
parentb6aa6702030320a3d5fbc2508307af0d7c947e40 (diff)
server: fix incorrectly reported token probabilities (#7125)
* server: normalize token probabilities * fix temperature == 0.0f
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp34
1 files changed, 24 insertions, 10 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ff0814b2..85ae1ad9 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2266,17 +2266,31 @@ struct server_context {
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;
- const int32_t n_probs = slot.sparams.n_probs;
- if (slot.sparams.temp <= 0 && n_probs > 0) {
- // for llama_sample_token_greedy we need to sort candidates
- llama_sample_softmax(ctx, &cur_p);
- }
+ const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
+ if (n_probs > 0) {
+ const size_t n_considered = slot.ctx_sampling->n_considered;
- for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
- result.probs.push_back({
- cur_p.data[i].id,
- cur_p.data[i].p
- });
+ // Make sure at least n_probs top tokens are at the front of the vector:
+ if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+ llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+ }
+
+ if (slot.sparams.temp == 0.0f) {
+ // With greedy sampling the probabilities have possibly not been calculated.
+ for (size_t i = 0; i < n_probs; ++i) {
+ result.probs.push_back({
+ cur_p.data[i].id,
+ i == 0 ? 1.0f : 0.0f
+ });
+ }
+ } else {
+ for (size_t i = 0; i < n_probs; ++i) {
+ result.probs.push_back({
+ cur_p.data[i].id,
+ i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+ });
+ }
+ }
}
if (!process_token(result, slot)) {