diff options
author | Johannes Gäßler <johannesg@5d6.de> | 2024-01-28 09:59:49 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-28 09:59:49 +0100 |
commit | 9241c3a2ace544aef708334e54bbdddb90208ee8 (patch) | |
tree | 6854a475e92afdf7bfc8ff0c57199e528f17079f /llama.cpp | |
parent | b2b2bf988c098851b4f3831f0cf38394bff75121 (diff) |
Apply min_p to unsorted tokens (#5115)
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 54 |
1 files changed, 45 insertions, 9 deletions
@@ -52,6 +52,7 @@ #include <algorithm> #include <array> #include <cassert> +#include <cfloat> #include <cinttypes> #include <climits> #include <cmath> @@ -8246,21 +8247,56 @@ void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * can return; } - llama_sample_softmax(ctx, candidates); - const int64_t t_start_sample_us = ggml_time_us(); - float scale = candidates->data[0].p; // scale by max prob - size_t i = 1; // first token always matches + bool min_p_applied = false; + + // if the candidates aren't sorted, try the unsorted implementation first + if (!candidates->sorted) { + std::vector<llama_token_data> filtered_tokens; + + float max_logit = -FLT_MAX; + for (size_t i = 0; i < candidates->size; ++i) { + max_logit = std::max(max_logit, candidates->data[i].logit); + } + const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max + + for (size_t i = 0; i < candidates->size; ++i) { + if (candidates->data[i].logit >= min_logit) { + filtered_tokens.push_back(candidates->data[i]); + } + } - for (; i < candidates->size; ++i) { - if (candidates->data[i].p < p * scale && i >= min_keep) { - break; // prob too small + // if we have enough values the operation was a success + if (filtered_tokens.size() >= min_keep) { + memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); + candidates->size = filtered_tokens.size(); + min_p_applied = true; } } - // Resize the output vector to keep only the matching tokens - candidates->size = i; + // if the candidates are sorted or the unsorted implementation failed, use this implementation + if (!min_p_applied) { + // Sort the logits in descending order + if (!candidates->sorted) { + std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + candidates->sorted = true; + } + + const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + size_t i = 1; // first token always matches + + for (; i < candidates->size; ++i) { + if (candidates->data[i].logit < min_logit && i >= min_keep) { + break; // prob too small + } + } + + // Resize the output vector to keep only the matching tokens + candidates->size = i; + } if (ctx) { ctx->t_sample_us += ggml_time_us() - t_start_sample_us; |