diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/llama-sampling.cpp | 34 | ||||
-rw-r--r-- | src/llama-sampling.h | 1 | ||||
-rw-r--r-- | src/llama.cpp | 5 |
3 files changed, 40 insertions, 0 deletions
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 8910f6d6..06f44b02 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -434,6 +434,40 @@ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array
     }
 }
 
+// XTC sampler: with chance `probability`, removes every candidate whose p is >= `threshold`
+// except the last one found by the scan below. No-op when probability <= 0, threshold > 0.5,
+// fewer than 2 candidates exist, or fewer than `min_keep` candidates would be left.
+void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) {
+    if (probability <= 0 || threshold > 0.5f || candidates->size < 2) {
+        return;
+    }
+    GGML_ASSERT(smpl);
+    const int64_t t_start_sample_us = ggml_time_us();
+    if (probability < 1) {
+        std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+        const float chance = distribution(smpl->rng);
+        if (chance > probability) return;
+    }
+
+    // NOTE(review): the scan presumes softmax leaves candidates ordered by descending p - confirm
+    llama_sample_softmax_impl(nullptr, candidates);
+    // index of the last candidate in the leading run with p >= threshold (stays 0 when at most one qualifies)
+    size_t pos_last = 0;
+    for (size_t i = 0; i < candidates->size; ++i) {
+        if (candidates->data[i].p >= threshold) {
+            pos_last = i;
+        } else break;
+    }
+
+    // skip the candidates in front of pos_last by advancing the array view
+    if (pos_last > 0 && candidates->size - pos_last >= min_keep) {
+        candidates->data += pos_last;
+        candidates->size -= pos_last;
+    }
+    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+    smpl->n_sample++;
+}
+
 void llama_sample_repetition_penalties_impl(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates,
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index f7f8e3ef..c2a9e45f 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -32,6 +32,7 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_
 void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
 void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
 void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
+void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
 
 void llama_sample_repetition_penalties_impl(
         struct llama_sampling * smpl,
diff --git a/src/llama.cpp b/src/llama.cpp
index 18c7cd0f..90e342e1 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -23265,6 +23265,11 @@ void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * cand
     llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp);
 }
 
+void llama_sample_xtc(struct llama_context * ctx, llama_token_data_array * candidates_p,
+        float probability, float threshold, size_t min_keep) {
+    llama_sample_xtc_impl(ctx ? &ctx->sampling : nullptr, candidates_p, probability, threshold, min_keep);
+}
+
 void llama_sample_repetition_penalties(
         struct llama_context * ctx,
         llama_token_data_array * candidates,