summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-03 11:32:03 +0300
committerGitHub <noreply@github.com>2025-06-03 11:32:03 +0300
commitccb265c01676aad9ae5860ba50e74e61dfcd1cf8 (patch)
tree8e2d9303bd091c4d0015fce8402162346d998cca /src
parent4f8b05a0d76e6c5e47fe1f6c7bd079e0fe95dbba (diff)
Adding the XTC sampler (#486)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src')
-rw-r--r--src/llama-sampling.cpp34
-rw-r--r--src/llama-sampling.h1
-rw-r--r--src/llama.cpp5
3 files changed, 40 insertions, 0 deletions
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8910f6d6..06f44b02 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -434,6 +434,40 @@ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array
}
}
+void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) {
+ if (probability < 0 || threshold > 0.5f || candidates->size < 2) {
+ return;
+ }
+ GGML_ASSERT(smpl);
+ const int64_t t_start_sample_us = ggml_time_us();
+ if (probability < 1) {
+ std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+ float chance = distribution(smpl->rng);
+ if (chance > probability) return;
+ }
+
+ llama_sample_softmax_impl(nullptr, candidates);
+
+ auto cur_size = candidates->size;
+
+ int pos_last = 0;
+
+ for (size_t i = 0; i < candidates->size; ++i) {
+ if (candidates->data[i].p >= threshold) {
+ pos_last = i;
+ } else break;
+ }
+
+ if (candidates->size - pos_last >= min_keep && pos_last > 0) {
+ candidates->data += pos_last;
+ candidates->size -= pos_last;
+ }
+
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ smpl->n_sample++;
+
+}
+
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index f7f8e3ef..c2a9e45f 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -32,6 +32,7 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
+void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
diff --git a/src/llama.cpp b/src/llama.cpp
index 18c7cd0f..90e342e1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -23265,6 +23265,11 @@ void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * cand
llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
}
+void llama_sample_xtc(struct llama_context * ctx, llama_token_data_array * candidates_p,
+ float probability, float threshold, size_t min_keep) {
+ llama_sample_xtc_impl(ctx ? &ctx->sampling : nullptr, candidates_p, probability, threshold, min_keep);
+}
+
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,