summaryrefslogtreecommitdiff
path: root/src/llama-sampling.h
diff options
context:
space:
mode:
authorfirecoperana <xuqiaowei1124@gmail.com>2025-06-19 02:24:53 -0500
committerGitHub <noreply@github.com>2025-06-19 10:24:53 +0300
commit3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa (patch)
treea3a17ee74e0436253e17f0d322320ed554d34b0a /src/llama-sampling.h
parentc5368148cf3af7a3694e0eb03d24a08326c01d12 (diff)
add dry sampler (#513)
* add dry sampler * use vocab instead of model in dry_init function * fix compile error for build test --------- Co-authored-by: firecoperana <firecoperana>
Diffstat (limited to 'src/llama-sampling.h')
-rw-r--r--src/llama-sampling.h32
1 files changed, 31 insertions, 1 deletions
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index 69d92a3a..855278e2 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -1,7 +1,7 @@
#pragma once
#include "llama-impl.h"
-
+#include <unordered_map>
struct llama_sampling {
llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
@@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
+struct llama_sampler_dry {
+ int32_t total_context_size;
+
+ const float dry_multiplier;
+ const float dry_base;
+ const int32_t dry_allowed_length;
+ const int32_t dry_penalty_last_n;
+
+ std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
+ std::vector<int> dry_repeat_count;
+ std::unordered_map<llama_token, int> dry_max_token_repeat;
+ ring_buffer<llama_token> last_tokens;
+};
+
+struct llama_sampler_dry * llama_sampler_init_dry_impl(
+ const struct llama_vocab & vocab,
+ int32_t context_size,
+ float dry_multiplier,
+ float dry_base,
+ int32_t dry_allowed_length,
+ int32_t dry_penalty_last_n,
+ const char ** seq_breakers,
+ size_t num_breakers);
+
+void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p);
+
+
+
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
@@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll
llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
+
+