summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/llama-impl.h114
-rw-r--r--src/llama-sampling.cpp311
-rw-r--r--src/llama-sampling.h32
-rw-r--r--src/llama-vocab.cpp19
-rw-r--r--src/llama-vocab.h7
-rw-r--r--src/llama.cpp45
6 files changed, 526 insertions, 2 deletions
diff --git a/src/llama-impl.h b/src/llama-impl.h
index a9cbe0df..a50f60cf 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -9,6 +9,7 @@
#define LLAMA_API_INTERNAL
#include "llama.h"
+#include <stdexcept>
#ifdef __GNUC__
#ifdef __MINGW32__
@@ -20,6 +21,7 @@
#define LLAMA_ATTRIBUTE_FORMAT(...)
#endif
+
//
// logging
//
@@ -52,3 +54,115 @@ static void replace_all(std::string & s, const std::string & search, const std::
builder.append(s, last_pos, std::string::npos);
s = std::move(builder);
}
+
+
+// the ring buffer works similarly to std::deque, but with a fixed capacity
+template<typename T>
+struct ring_buffer {
+ ring_buffer(size_t cap) : capacity(cap), data(cap) {}
+
+ T& front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[first];
+ }
+
+ const T& front() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[first];
+ }
+
+ T& back() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[pos];
+ }
+
+ const T& back() const {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ return data[pos];
+ }
+
+ void push_back(const T& value) {
+ if (capacity == 0) {
+ throw std::runtime_error("ring buffer: capacity is zero");
+ }
+
+ if (sz == capacity) {
+ // advance the start when buffer is full
+ first = (first + 1) % capacity;
+ }
+ else {
+ sz++;
+ }
+ data[pos] = value;
+ pos = (pos + 1) % capacity;
+ }
+
+ T pop_front() {
+ if (sz == 0) {
+ throw std::runtime_error("ring buffer is empty");
+ }
+ T value = data[first];
+ first = (first + 1) % capacity;
+ sz--;
+ return value;
+ }
+
+ //T & operator[](size_t i) {
+ // if (i >= sz) {
+ // throw std::runtime_error("ring buffer: index out of bounds");
+ // }
+ // return data[(first + i) % capacity];
+ //}
+
+ //const T & at(size_t i) const {
+ // if (i >= sz) {
+ // throw std::runtime_error("ring buffer: index out of bounds");
+ // }
+ // return data[(first + i) % capacity];
+ //}
+
+ const T& rat(size_t i) const {
+ if (i >= sz) {
+ throw std::runtime_error("ring buffer: index out of bounds");
+ }
+ return data[(first + sz - i - 1) % capacity];
+ }
+
+ std::vector<T> to_vector() const {
+ std::vector<T> result;
+ result.reserve(sz);
+ for (size_t i = 0; i < sz; i++) {
+ result.push_back(data[(first + i) % capacity]);
+ }
+ return result;
+ }
+
+ void clear() {
+ // here only reset the status of the buffer
+ sz = 0;
+ first = 0;
+ pos = 0;
+ }
+
+ bool empty() const {
+ return sz == 0;
+ }
+
+ size_t size() const {
+ return sz;
+ }
+
+ size_t capacity = 0;
+ size_t sz = 0;
+ size_t first = 0;
+ size_t pos = 0;
+ std::vector<T> data;
+};
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 7a185c5b..40d9963d 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1,4 +1,6 @@
#include "llama-sampling.h"
+#include "llama-vocab.h"
+#include "llama-grammar.h"
#include <algorithm>
#include <cstring>
@@ -469,7 +471,7 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
}
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) {
-
+
if (top_n_sigma <= 0.0f || candidates->size < 4) {
// top_n_sigma <= 0: disabled
// candidates->size < 4: no point in applying the transformation for fewer than 4 logits.
@@ -725,3 +727,310 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama
llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
}
+
+
+// DRY
+
+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
+static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
+ for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) {
+ std::string word = llama_detokenize(vocab, { token_id }, true);
+ if (word.find(str) != std::string::npos) {
+ token_sequences.emplace(token_id, std::vector<llama_token>());
+ }
+ else {
+ size_t word_len = word.size(), str_len = str.size();
+ size_t pos = -1;
+ while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
+ bool match = true;
+ size_t i;
+ for (i = 1; i < str_len && i + pos < word_len; ++i) {
+ if (word[pos + i] != str[i]) {
+ match = false;
+ break;
+ }
+ }
+ if (match) {
+ std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+ if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
+ tokenization.resize(max_tail_len);
+ }
+
+ // Ensure we don't already have a duplicate matching tokenization
+ auto its = token_sequences.equal_range(token_id);
+ bool found = false;
+ for (auto it = its.first; it != its.second; ++it) {
+ if (tokenization == it->second) {
+ found = true;
+ break;
+ }
+ }
+ if (!found) {
+ token_sequences.emplace(token_id, tokenization);
+ }
+ }
+ }
+ }
+ }
+}
+
+static const char* llama_sampler_dry_name(const struct llama_sampler* /*smpl*/) {
+ return "dry";
+}
+
+
+
+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
+void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p) {
+ if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
+ return;
+ }
+
+ int32_t effective_dry_penalty_last_n = (smpl->dry_penalty_last_n == -1) ? smpl->total_context_size : std::max(smpl->dry_penalty_last_n, 0);
+ int last_n_repeat = std::min(std::min((int)smpl->last_tokens.size(), effective_dry_penalty_last_n), smpl->total_context_size);
+
+ if (last_n_repeat <= smpl->dry_allowed_length) {
+ return;
+ }
+
+ smpl->dry_repeat_count.assign(last_n_repeat, 0);
+ smpl->dry_max_token_repeat.clear();
+
+ // Step 1: Look for restart sequences to limit the maximum repetition length.
+ // Work backwards through the context looking for any token that begins a restart sequence.
+ //
+ // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
+ // sequences that together comprise a restart sequence. This allows us to quickly check
+ // whether each token is the head of a complete sequence. Most restart sequences are actually
+ // a single token, and for these the "tail" is an empty vector.
+ //
+ // If the token is a "head", test all restart sequences that begin with this token
+ // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
+ // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
+ // longest matching sequence (if any) is used to limit the maximum repetition length.
+ //
+ // Note that in the case case of a short sequence contained in a longer one, this might fail to
+ // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
+ // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
+ // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
+ //
+ // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
+ // have already clamped the maximum tail sequence length when generating `restart_sequences`.
+ // With clamping, this scan is O(N) in the context length.
+
+ int rep_limit = last_n_repeat;
+ for (int i = 0; i < last_n_repeat; ++i) {
+ llama_token token = smpl->last_tokens.rat(i);
+ auto its = smpl->dry_processed_breakers.equal_range(token);
+ if (its.first == smpl->dry_processed_breakers.end()) {
+ continue;
+ }
+ int longest_match = -1;
+ for (auto it = its.first; it != its.second; ++it) {
+ // Note that (*it) does not contain the head character, so seq_len will be
+ // the restart sequence length minus 1.
+ // In the common case of a single-token restart sequence, (*it) will be empty
+ // and we will trivially match.
+ int seq_len = (int)it->second.size();
+ if (seq_len > longest_match && seq_len <= (int)i) {
+ bool match = true;
+ for (int offset = 0; offset < seq_len; ++offset) {
+ // The -1 when indexing `last_tokens` is because we already matched the head.
+ if (it->second[offset] != smpl->last_tokens.rat(i - offset - 1)) {
+ match = false;
+ break;
+ }
+ }
+ if (match) {
+ longest_match = seq_len;
+ }
+ }
+ }
+ if (longest_match >= 0) {
+ // We found a restart sequence starting `i` tokens from the end and continuing for
+ // `longest_match` tokens.
+ rep_limit = i - longest_match;
+ break;
+ }
+ }
+ if (rep_limit < smpl->dry_allowed_length) {
+ return;
+ }
+
+ // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
+ // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
+ // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
+ //
+ // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
+ // https://ivanyu.me/blog/2014/10/15/z-algorithm/
+ //
+ // The code below is adapted from the public domain implementation by the same author here:
+ // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
+ //
+ // Example:
+ // Last N tokens: a b c c b c y a b c
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
+ // ^
+ // This `3` means that the last three tokens of the context (a b c) also appear here.
+ //
+ // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
+ // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
+ // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
+ // ensure that the inner while loops only examine each token in the context once as the outer
+ // for loop iterates over the context.
+
+ {
+ const int last = last_n_repeat - 1;
+ int rt = 0, lt = 0;
+
+ for (int k = 1; k < last_n_repeat; ++k) {
+ if (k > rt) {
+ // If k is outside the current Z-box, do naive computation.
+ int n = 0;
+ while (n + k < last_n_repeat && smpl->last_tokens.rat(n) == smpl->last_tokens.rat(n + k)) {
+ ++n;
+ }
+ smpl->dry_repeat_count[last - k] = std::min(n, rep_limit);
+ if (n > 0) {
+ lt = k;
+ rt = k + n - 1;
+ }
+ }
+ else {
+ // If k is inside the current Z-box, consider two cases.
+
+ int p = k - lt; // Pair index.
+ int right_part_len = rt - k + 1;
+
+ if (smpl->dry_repeat_count[last - p] < right_part_len) {
+ int n = std::min(smpl->dry_repeat_count[last - p], rep_limit);
+ smpl->dry_repeat_count[last - k] = n;
+ }
+ else {
+ int i = rt + 1;
+ while (i < last_n_repeat && smpl->last_tokens.rat(i) == smpl->last_tokens.rat(i - k)) {
+ i += 1;
+ }
+
+ int n = std::min(i - k, rep_limit);
+ smpl->dry_repeat_count[last - k] = n;
+ lt = k;
+ rt = i - 1;
+ }
+ }
+ }
+ }
+
+ // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
+ // that would be generated by emitting each new token that would extend a sequence.
+ //
+ // Following the same example as above:
+ // Last N tokens: a b c c b c y a b c
+ // Repeat counts: 0 0 3 1 0 2 0 0 0 0
+ //
+ // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
+ // c: 3 -> 4 (from `a b c` to `a b c c`)
+ // b: 1 -> 2 (from `c` to `c b`)
+ // y: 2 -> 3 (from `b c` to `b c y`)
+
+ for (int i = 0; i < last_n_repeat - 1; ++i) {
+ int repeat_len = smpl->dry_repeat_count[i];
+ if (repeat_len >= smpl->dry_allowed_length) {
+ // This token ends a repeat, so the next token would continue one.
+ // By convention, the value of `repeat_len` only includes the tokens currently
+ // in the context, not the new token that would be added.
+ llama_token token = smpl->last_tokens.rat(last_n_repeat - 2 - i);
+ // Track the maximum sequence ending in this token.
+ const auto& it = smpl->dry_max_token_repeat.find(token);
+ if (it == smpl->dry_max_token_repeat.end() || it->second < repeat_len) {
+ smpl->dry_max_token_repeat[token] = repeat_len;
+ }
+ }
+ }
+
+ // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
+
+ // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
+ // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
+ const float FLOAT_MAX_LOG = 88.7228391f;
+ int max_exponent = 0;
+ if (smpl->dry_base > 1.000001f) {
+ max_exponent = FLOAT_MAX_LOG / std::log(smpl->dry_base);
+ }
+
+ for (size_t i = 0; i < cur_p->size; ++i) {
+ const auto& af_kvp = smpl->dry_max_token_repeat.find(cur_p->data[i].id);
+ if (af_kvp != smpl->dry_max_token_repeat.end()) {
+ // Check all sequence breakers starting with this token
+ auto range = smpl->dry_processed_breakers.equal_range(cur_p->data[i].id);
+ bool is_single_token_breaker = false;
+
+ for (auto it = range.first; it != range.second; ++it) {
+ if (it->second.empty()) {
+ is_single_token_breaker = true;
+ break;
+ }
+ }
+
+ // Apply penalty only if it's not a single-token sequence breaker
+ if (!is_single_token_breaker) {
+ int repeat_exp = af_kvp->second - smpl->dry_allowed_length;
+ if (max_exponent > 0 && repeat_exp > max_exponent) {
+ repeat_exp = max_exponent;
+ }
+ float penalty = smpl->dry_multiplier * std::pow(smpl->dry_base, repeat_exp);
+ cur_p->data[i].logit -= penalty;
+ }
+ }
+ }
+
+ cur_p->sorted = false;
+}
+
+
+
+struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+ int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
+ std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
+ const int MAX_CHAR_LEN = 40;
+ const int MAX_SEQ_LEN = 20;
+
+ const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
+
+ if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
+ // Process sequence breakers
+ for (size_t i = 0; i < num_breakers; ++i) {
+ if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
+ LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
+ continue;
+ }
+
+ std::string sequence_break(seq_breakers[i]);
+ if (sequence_break.empty()) {
+ LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
+ continue;
+ }
+
+ if (sequence_break.size() > MAX_CHAR_LEN) {
+ LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
+ sequence_break.resize(MAX_CHAR_LEN);
+ }
+
+ get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+ }
+ }
+
+ return new llama_sampler_dry {
+ /* .total_context_size = */ context_size,
+ /* .dry_multiplier = */ dry_multiplier,
+ /* .dry_base = */ dry_base,
+ /* .dry_allowed_length = */ dry_allowed_length,
+ /* .dry_penalty_last_n = */ dry_penalty_last_n,
+ /* .dry_processed_breakers = */ std::move(processed_breakers),
+ /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
+ /* .dry_max_token_repeat = */ {},
+ /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
+ };
+}
+
+
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 69d92a3a..855278e2 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -1,7 +1,7 @@
#pragma once
#include "llama-impl.h"
-
+#include <unordered_map>
struct llama_sampling {
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
@@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
+struct llama_sampler_dry {
+ int32_t total_context_size;
+
+ const float dry_multiplier;
+ const float dry_base;
+ const int32_t dry_allowed_length;
+ const int32_t dry_penalty_last_n;
+
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
+ std::vector<int> dry_repeat_count;
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
+ ring_buffer<llama_token> last_tokens;
+};
+
+struct llama_sampler_dry * llama_sampler_init_dry_impl(
+ const struct llama_vocab & vocab,
+ int32_t context_size,
+ float dry_multiplier,
+ float dry_base,
+ int32_t dry_allowed_length,
+ int32_t dry_penalty_last_n,
+ const char ** seq_breakers,
+ size_t num_breakers);
+
+void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
+
+
+
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
@@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
+
+
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 09399417..abf48824 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -75,6 +75,9 @@ struct naive_trie {
llama_token value;
};
+uint32_t llama_vocab::n_tokens() const {
+ return (uint32_t)id_to_token.size();
+}
//
// impl
//
@@ -1741,3 +1744,19 @@ int32_t llama_detokenize_impl(
return total <= text_len_max ? total : -total;
}
+
+std::string llama_detokenize(const struct llama_vocab& vocab, const std::vector<llama_token>& tokens, bool special) {
+ std::string text;
+ text.resize(std::max(text.capacity(), tokens.size()));
+ int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+ if (n_chars < 0) {
+ text.resize(-n_chars);
+ n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+ GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
+ }
+
+ text.resize(n_chars);
+
+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+ return text;
+}
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 7adfc16d..a461eca0 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -23,6 +23,8 @@ struct llama_vocab {
int max_token_len = 0; // used for optimizing longest token search
+ uint32_t n_tokens() const;
+
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
@@ -130,3 +132,8 @@ int32_t llama_detokenize_impl(
int32_t text_len_max,
bool remove_special,
bool unparse_special);
+
+std::string llama_detokenize(
+ const struct llama_vocab& vocab,
+ const std::vector<llama_token>& tokens,
+ bool special);
diff --git a/src/llama.cpp b/src/llama.cpp
index af8ef9be..c0f147b9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -20849,6 +20849,10 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.type;
}
+const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) {
+ return &model->vocab;
+}
+
enum llama_rope_type llama_rope_type(const struct llama_model * model) {
switch (model->arch) {
// these models do not use RoPE
@@ -23280,6 +23284,11 @@ void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array
llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
}
+
+void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) {
+ llama_sampler_dry_apply(smpl, candidates_p);
+}
+
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
@@ -23327,6 +23336,42 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix,
return 0;
}
+struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+ return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
+}
+
+void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) {
+ smpl->last_tokens.clear();
+ smpl->dry_repeat_count.clear();
+ smpl->dry_max_token_repeat.clear();
+}
+
+void llama_sampler_dry_free(struct llama_sampler_dry* smpl) {
+ delete smpl;
+}
+
+struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) {
+ // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
+ auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0);
+ // Copy the state, including the processed breakers
+ {
+ auto* result_ctx = smpl;
+ result_ctx->dry_processed_breakers = smpl->dry_processed_breakers;
+ result_ctx->dry_repeat_count = smpl->dry_repeat_count;
+ result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat;
+ result_ctx->last_tokens = smpl->last_tokens;
+ }
+
+ return result;
+}
+
+void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) {
+ if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) {
+ return;
+ }
+ smpl->last_tokens.push_back(token);
+}
+
int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) {
std::string str_split_path(split_path);
char postfix[32];