diff options
Diffstat (limited to 'src/llama-sampling.cpp')
-rw-r--r-- | src/llama-sampling.cpp | 311 |
1 files changed, 310 insertions, 1 deletions
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7a185c5b..40d9963d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,4 +1,6 @@ #include "llama-sampling.h" +#include "llama-vocab.h" +#include "llama-grammar.h" #include <algorithm> #include <cstring> @@ -469,7 +471,7 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array } void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) { - + if (top_n_sigma <= 0.0f || candidates->size < 4) { // top_n_sigma <= 0: disabled // candidates->size < 4: no point in applying the transformation for fewer than 4 logits. @@ -725,3 +727,310 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); } + + +// DRY + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) { + std::string word = llama_detokenize(vocab, { token_id }, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector<llama_token>()); + } + else { + size_t word_len = word.size(), str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char* llama_sampler_dry_name(const struct llama_sampler* /*smpl*/) { + return "dry"; +} + + + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p) { + if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (smpl->dry_penalty_last_n == -1) ? smpl->total_context_size : std::max(smpl->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)smpl->last_tokens.size(), effective_dry_penalty_last_n), smpl->total_context_size); + + if (last_n_repeat <= smpl->dry_allowed_length) { + return; + } + + smpl->dry_repeat_count.assign(last_n_repeat, 0); + smpl->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = smpl->last_tokens.rat(i); + auto its = smpl->dry_processed_breakers.equal_range(token); + if (its.first == smpl->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != smpl->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < smpl->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && smpl->last_tokens.rat(n) == smpl->last_tokens.rat(n + k)) { + ++n; + } + smpl->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k + n - 1; + } + } + else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (smpl->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(smpl->dry_repeat_count[last - p], rep_limit); + smpl->dry_repeat_count[last - k] = n; + } + else { + int i = rt + 1; + while (i < last_n_repeat && smpl->last_tokens.rat(i) == smpl->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + smpl->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = smpl->dry_repeat_count[i]; + if (repeat_len >= smpl->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = smpl->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = smpl->dry_max_token_repeat.find(token); + if (it == smpl->dry_max_token_repeat.end() || it->second < repeat_len) { + smpl->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (smpl->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(smpl->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = smpl->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != smpl->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = smpl->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - smpl->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = smpl->dry_multiplier * std::pow(smpl->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + + + +struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), + }; +} + + |