summaryrefslogtreecommitdiff
path: root/src/llama-sampling.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/llama-sampling.h')
-rw-r--r--src/llama-sampling.h32
1 files changed, 31 insertions, 1 deletions
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 69d92a3a..855278e2 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -1,7 +1,7 @@
#pragma once
#include "llama-impl.h"
-
+#include <unordered_map>
struct llama_sampling {
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
@@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
+struct llama_sampler_dry {
+ int32_t total_context_size;
+
+ const float dry_multiplier;
+ const float dry_base;
+ const int32_t dry_allowed_length;
+ const int32_t dry_penalty_last_n;
+
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
+ std::vector<int> dry_repeat_count;
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
+ ring_buffer<llama_token> last_tokens;
+};
+
+struct llama_sampler_dry * llama_sampler_init_dry_impl(
+ const struct llama_vocab & vocab,
+ int32_t context_size,
+ float dry_multiplier,
+ float dry_base,
+ int32_t dry_allowed_length,
+ int32_t dry_penalty_last_n,
+ const char ** seq_breakers,
+ size_t num_breakers);
+
+void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
+
+
+
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
@@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
+
+