diff options
Diffstat (limited to 'src/llama-sampling.h')
-rw-r--r-- | src/llama-sampling.h | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 69d92a3a..855278e2 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,7 +1,7 @@ #pragma once #include "llama-impl.h" - +#include <unordered_map> struct llama_sampling { llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} @@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_ void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep); void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma); +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers; + std::vector<int> dry_repeat_count; + std::unordered_map<llama_token, int> dry_max_token_repeat; + ring_buffer<llama_token> last_tokens; +}; + +struct llama_sampler_dry * llama_sampler_init_dry_impl( + const struct llama_vocab & vocab, + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + +void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p); + + + void llama_sample_repetition_penalties_impl( struct llama_sampling * smpl, llama_token_data_array * candidates, @@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); + + |