Diffstat (limited to 'src')
-rw-r--r-- | src/llama-sampling.cpp | 60
-rw-r--r-- | src/llama-sampling.h   |  1
-rw-r--r-- | src/llama.cpp          |  4
3 files changed, 64 insertions, 1 deletion
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 06f44b02..7a185c5b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -435,7 +435,7 @@ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array
 }
 
 void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) {
-    if (probability < 0 || threshold > 0.5f || candidates->size < 2) {
+    if (probability <= 0 || threshold > 0.5f || candidates->size < 2) {
         return;
     }
     GGML_ASSERT(smpl);
@@ -468,6 +468,64 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
 }
 
+void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) {
+
+    if (top_n_sigma <= 0.0f || candidates->size < 4) {
+        // top_n_sigma <= 0: disabled
+        // candidates->size < 4: no point in applying the transformation for fewer than 4 logits.
+        return;
+    }
+
+    const int64_t t_start_sample_us = ggml_time_us();
+
+    float max = candidates->data[0].logit;
+    float mean = 0;
+    size_t count = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        // Only count logits that are not -INFINITY
+        if (candidates->data[i].logit != -INFINITY) {
+            max = std::max(max, candidates->data[i].logit);
+            mean += candidates->data[i].logit;
+            ++count;
+        }
+    }
+    if (count < 4) {
+        return; // again, the standard deviation is not well defined for so few logits (4 is actually pushing it)
+    }
+    mean /= count;
+
+    float sigma2 = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        if (candidates->data[i].logit != -INFINITY) {
+            float delta = candidates->data[i].logit - mean;
+            sigma2 += delta*delta;
+        }
+    }
+    float sigma = sqrtf(sigma2/count);
+    float thresh = max - top_n_sigma*sigma;
+
+    int n_masked = 0;
+    for (int i = 0; i < (int)candidates->size; ++i) {
+        if (candidates->data[i].logit != -INFINITY && candidates->data[i].logit < thresh) {
+            candidates->data[i].logit = -INFINITY;
+            ++n_masked;
+        }
+    }
+
+    // do we really want to compute softmax unconditionally?
+    // The following corresponds to the mainline implementation, with the minor optimization
+    // that we only call the relatively expensive softmax if we masked away some tokens.
+    if (n_masked > 0 || !candidates->sorted) {
+        llama_sample_softmax_impl(nullptr, candidates);
+    }
+
+    if (smpl) {
+        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+        smpl->n_sample++;
+    }
+}
+
+
 void llama_sample_repetition_penalties_impl(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates,
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index c2a9e45f..69d92a3a 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -33,6 +33,7 @@ void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_
 void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
 void llama_sample_temp_impl    (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
 void llama_sample_xtc_impl     (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
+void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
 
 void llama_sample_repetition_penalties_impl(
         struct llama_sampling * smpl,
diff --git a/src/llama.cpp b/src/llama.cpp
index 90e342e1..be404500 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -23270,6 +23270,10 @@ void llama_sample_xtc(struct llama_context * ctx, llama_token_data_array * candi
     llama_sample_xtc_impl(ctx ? &ctx->sampling : nullptr, candidates_p, probability, threshold, min_keep);
 }
 
+void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array * candidates_p, float top_n_sigma) {
+    llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
+}
+
 void llama_sample_repetition_penalties(
         struct llama_context * ctx,
         llama_token_data_array * candidates,
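
For reference, the following is a minimal, self-contained sketch (not part of the patch) of the transformation the new sampler applies: compute the mean and standard deviation over the logits that are not already masked to -INFINITY, then mask every logit that falls more than top_n_sigma standard deviations below the maximum. The helper name and the example values are illustrative only.

// Standalone illustration of the top-n-sigma transform on a plain vector of logits.
#include <cmath>
#include <cstdio>
#include <vector>

static void top_n_sigma(std::vector<float> & logits, float n) {
    float max = -INFINITY, mean = 0.0f;
    size_t count = 0;
    for (float x : logits) {
        if (x != -INFINITY) { max = std::fmax(max, x); mean += x; ++count; }
    }
    if (count < 4) return;                       // same guard as the patch
    mean /= count;

    float sigma2 = 0.0f;
    for (float x : logits) {
        if (x != -INFINITY) { float d = x - mean; sigma2 += d*d; }
    }
    const float thresh = max - n*std::sqrt(sigma2/count);

    for (float & x : logits) {
        if (x != -INFINITY && x < thresh) {
            x = -INFINITY;                       // mask everything below max - n*sigma
        }
    }
}

int main() {
    std::vector<float> logits = {8.0f, 7.5f, 3.0f, 2.0f, 1.0f, -INFINITY};
    top_n_sigma(logits, 1.0f);                   // here sigma is about 2.9, so only 8.0 and 7.5 survive
    for (float x : logits) printf("%g ", x);
    printf("\n");
}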