Diffstat (limited to 'src/llama-sampling.cpp')
 src/llama-sampling.cpp | 60
 1 file changed, 59 insertions(+), 1 deletion(-)
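
For orientation, the patch below adds Top-nσ sampling: tokens whose logits fall
more than top_n_sigma standard deviations below the maximum logit are masked
out, and the survivors are renormalized with a softmax. The following is a
minimal, self-contained sketch of that transform, not the patch itself; the
token_data struct is a hypothetical stand-in for llama.cpp's llama_token_data,
and the logits in main() are made up for illustration.

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    struct token_data { int id; float logit; }; // simplified stand-in

    // Mask every token whose logit lies more than n standard deviations
    // below the maximum logit (population std-dev over unmasked tokens).
    static void top_n_sigma(std::vector<token_data> & cand, float n) {
        if (n <= 0.0f || cand.size() < 4) {
            return; // disabled, or too few logits for a meaningful std-dev
        }
        float max = cand[0].logit, mean = 0.0f;
        std::size_t count = 0;
        for (const auto & t : cand) {
            if (t.logit != -INFINITY) { // skip already-masked tokens
                max = std::max(max, t.logit);
                mean += t.logit;
                ++count;
            }
        }
        if (count < 4) {
            return;
        }
        mean /= count;
        float sigma2 = 0.0f;
        for (const auto & t : cand) {
            if (t.logit != -INFINITY) {
                const float d = t.logit - mean;
                sigma2 += d*d;
            }
        }
        const float thresh = max - n*std::sqrt(sigma2/count);
        for (auto & t : cand) {
            if (t.logit != -INFINITY && t.logit < thresh) {
                t.logit = -INFINITY; // a softmax pass would renormalize afterwards
            }
        }
    }

    int main() {
        // max = 9.0, mean = 4.6, sigma ~ 3.6 => threshold ~ 5.4 for n = 1
        std::vector<token_data> cand = {{0, 9.0f}, {1, 8.5f}, {2, 4.0f}, {3, 1.0f}, {4, 0.5f}};
        top_n_sigma(cand, 1.0f);
        for (const auto & t : cand) {
            std::printf("token %d: logit %g\n", t.id, t.logit); // tokens 2-4 print -inf
        }
    }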
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 06f44b02..7a185c5b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -435,7 +435,7 @@ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array
 }
 void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) {
-    if (probability < 0 || threshold > 0.5f || candidates->size < 2) {
+    if (probability <= 0 || threshold > 0.5f || candidates->size < 2) {
         return;
     }
     GGML_ASSERT(smpl);
@@ -468,6 +468,64 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
 }
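+// Top-nσ sampling: mask out every candidate whose logit lies more than
+// top_n_sigma standard deviations below the maximum logit.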
+void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) {
+
+    if (top_n_sigma <= 0.0f || candidates->size < 4) {
+        // top_n_sigma <= 0: disabled
+        // candidates->size < 4: no point in applying the transformation for fewer than 4 logits.
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
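+    // first pass: maximum and mean over the logits that are not already masked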
+    float max = candidates->data[0].logit;
+    float mean = 0;
+    size_t count = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        // only count tokens that have not already been masked with -INFINITY
+        if (candidates->data[i].logit != -INFINITY) {
+            max = std::max(max, candidates->data[i].logit);
+            mean += candidates->data[i].logit;
+            ++count;
+        }
+    }
+    if (count < 4) {
+        return; // again, the standard deviation is not well defined for so few logits (4 is actually pushing it)
+    }
+    mean /= count;
+
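+    // second pass: population variance of the unmasked logits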
+    float sigma2 = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        if (candidates->data[i].logit != -INFINITY) {
+            float delta = candidates->data[i].logit - mean;
+            sigma2 += delta*delta;
+        }
+    }
+    float sigma = sqrtf(sigma2/count);
+    float thresh = max - top_n_sigma*sigma;
+
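+    // third pass: mask out everything below the threshold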
+    int n_masked = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        if (candidates->data[i].logit != -INFINITY && candidates->data[i].logit < thresh) {
+            candidates->data[i].logit = -INFINITY;
+            ++n_masked;
+        }
+    }
+
+    // Do we really want to compute the softmax unconditionally?
+    // The following corresponds to the mainline implementation, with the minor optimization
+    // that we only call the relatively expensive softmax if we actually masked away some tokens.
+    if (n_masked > 0 || !candidates->sorted) {
+        llama_sample_softmax_impl(nullptr, candidates);
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+        smpl->n_sample++;
+    }
+}
+
+
 void llama_sample_repetition_penalties_impl(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates,