diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-05-07 23:07:58 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-07 23:07:58 +0200 |
commit | af0a5b616359809ce886ea433acedebb39b12969 (patch) | |
tree | c9da4c284790e7ddd053c1c5f5e3b636581ee07a /examples/server/server.cpp | |
parent | b6aa6702030320a3d5fbc2508307af0d7c947e40 (diff) |
server: fix incorrectly reported token probabilities (#7125)
* server: normalize token probabilities
* fix temperature == 0.0f
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 34 |
1 files changed, 24 insertions, 10 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ff0814b2..85ae1ad9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2266,17 +2266,31 @@ struct server_context { llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; - const int32_t n_probs = slot.sparams.n_probs; - if (slot.sparams.temp <= 0 && n_probs > 0) { - // for llama_sample_token_greedy we need to sort candidates - llama_sample_softmax(ctx, &cur_p); - } + const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); + if (n_probs > 0) { + const size_t n_considered = slot.ctx_sampling->n_considered; - for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) { - result.probs.push_back({ - cur_p.data[i].id, - cur_p.data[i].p - }); + // Make sure at least n_probs top tokens are at the front of the vector: + if (slot.sparams.temp == 0.0f && n_probs > n_considered) { + llama_sample_top_k(ctx, &cur_p, n_probs, 0); + } + + if (slot.sparams.temp == 0.0f) { + // With greedy sampling the probabilities have possibly not been calculated. + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i == 0 ? 1.0f : 0.0f + }); + } + } else { + for (size_t i = 0; i < n_probs; ++i) { + result.probs.push_back({ + cur_p.data[i].id, + i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. + }); + } + } } if (!process_token(result, slot)) { |