diff options
author | firecoperana <xuqiaowei1124@gmail.com> | 2025-06-19 02:24:53 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-19 10:24:53 +0300 |
commit | 3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa (patch) | |
tree | a3a17ee74e0436253e17f0d322320ed554d34b0a /src/llama-sampling.h | |
parent | c5368148cf3af7a3694e0eb03d24a08326c01d12 (diff) |
add dry sampler (#513)
* add dry sampler
* use vocab instead of model in dry_init function
* fix compile error for build test
---------
Co-authored-by: firecoperana <firecoperana>
Diffstat (limited to 'src/llama-sampling.h')
-rw-r--r-- | src/llama-sampling.h | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 69d92a3a..855278e2 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,7 +1,7 @@ #pragma once #include "llama-impl.h" - +#include <unordered_map> struct llama_sampling { llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} @@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_ void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep); void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma); +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers; + std::vector<int> dry_repeat_count; + std::unordered_map<llama_token, int> dry_max_token_repeat; + ring_buffer<llama_token> last_tokens; +}; + +struct llama_sampler_dry * llama_sampler_init_dry_impl( + const struct llama_vocab & vocab, + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + +void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p); + + + void llama_sample_repetition_penalties_impl( struct llama_sampling * smpl, llama_token_data_array * candidates, @@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); + + |