summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-06-03 17:35:09 +0300
committerGitHub <noreply@github.com>2025-06-03 17:35:09 +0300
commitf6d5fbdc5780b6dca770c896b8463de3239c7f8b (patch)
tree5174cb76b596d23383a4434ab2179d8b5213512f
parentccb265c01676aad9ae5860ba50e74e61dfcd1cf8 (diff)
Adding top-n-sigma sampler (#489)
* Adding top-n-sigma sampler * Fix typos in XTC PR * Update README.md for main and server * More README * More README --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
-rw-r--r--common/common.cpp9
-rw-r--r--common/sampling.cpp22
-rw-r--r--common/sampling.h4
-rw-r--r--examples/main/README.md16
-rw-r--r--examples/server/README.md3
-rw-r--r--include/llama.h7
-rw-r--r--src/llama-sampling.cpp60
-rw-r--r--src/llama-sampling.h1
-rw-r--r--src/llama.cpp4
9 files changed, 115 insertions, 11 deletions
diff --git a/common/common.cpp b/common/common.cpp
index cefbf63f..232101e4 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -659,6 +659,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
sparams.xtc_threshold = std::stof(argv[i]);
return true;
}
+ if (arg == "--top-n-sigma") {
+ CHECK_ARG
+ sparams.top_n_sigma = std::stof(argv[i]);
+ return true;
+ }
if (arg == "--cfg-negative-prompt") {
CHECK_ARG
sparams.cfg_negative_prompt = argv[i];
@@ -1646,7 +1651,8 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --mirostat-lr N", "Mirostat learning rate, parameter eta (default: %.1f)", (double)sparams.mirostat_eta });
options.push_back({ "*", " --mirostat-ent N", "Mirostat target entropy, parameter tau (default: %.1f)", (double)sparams.mirostat_tau });
options.push_back({ "*", " --xtc-probability p", "xtc probability (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_probability });
- options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, 0.0 = disabled)", (double)sparams.xtc_threshold});
+ options.push_back({ "*", " --xtc-threshold t", "xtc threshold (default: %.1f, >0.5 = disabled)", (double)sparams.xtc_threshold});
+ options.push_back({ "*", " --top-n-sigma t", "top-n-sigma parmeter (default: %.1f, 0.0 = disabled)", (double)sparams.top_n_sigma});
options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n"
"i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"
"or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" });
@@ -3410,6 +3416,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability);
fprintf(stream, "xtc_threshold: %f # default: 0.0\n", sparams.xtc_threshold);
+ fprintf(stream, "top_n_sigma: %f # default: 0.0\n", sparams.top_n_sigma);
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH);
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 84691d93..4db12ee1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -122,11 +122,11 @@ std::string llama_sampling_print(const llama_sampling_params & params) {
"\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
"\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n"
"\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f\n"
- "\txtc_probability = %.3f, xtc_threshold = %.3f",
+ "\txtc_probability = %.3f, xtc_threshold = %.3f, top_n_sigma = %.3f",
params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp,
params.mirostat, params.mirostat_eta, params.mirostat_tau,
- params.xtc_probability, params.xtc_threshold);
+ params.xtc_probability, params.xtc_threshold, params.top_n_sigma);
return std::string(result);
}
@@ -156,6 +156,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) {
case llama_sampler_type::MIN_P: return "min_p";
case llama_sampler_type::TEMPERATURE: return "temperature";
case llama_sampler_type::XTC : return "xtc";
+ case llama_sampler_type::TOP_N_SIGMA: return "top_n_sigma";
default : return "";
}
}
@@ -168,6 +169,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"min_p", llama_sampler_type::MIN_P},
{"tfs_z", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
+ {"top_n_sigma", llama_sampler_type::TOP_N_SIGMA},
{"temperature", llama_sampler_type::TEMPERATURE}
};
@@ -183,6 +185,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto
{"tfs-z", llama_sampler_type::TFS_Z},
{"tfs", llama_sampler_type::TFS_Z},
{"xtc", llama_sampler_type::XTC},
+ {"top-n-sigma", llama_sampler_type::TOP_N_SIGMA},
{"temp", llama_sampler_type::TEMPERATURE}
};
@@ -218,6 +221,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
{'m', llama_sampler_type::MIN_P},
{'f', llama_sampler_type::TFS_Z},
{'x', llama_sampler_type::XTC},
+ {'n', llama_sampler_type::TOP_N_SIGMA},
{'t', llama_sampler_type::TEMPERATURE}
};
@@ -248,16 +252,18 @@ static void sampler_queue(
const float typical_p = params.typical_p;
const float xtc_probability = params.xtc_probability;
const float xtc_threshold = params.xtc_threshold;
+ const float top_n_sigma = params.top_n_sigma;
const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence;
for (auto sampler_type : samplers_sequence) {
switch (sampler_type) {
- case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
- case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
- case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
- case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
- case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
- case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
+ case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break;
+ case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break;
+ case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break;
+ case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break;
+ case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break;
+ case llama_sampler_type::XTC : llama_sample_xtc (ctx_main, &cur_p, xtc_probability, xtc_threshold, min_keep); break;
+ case llama_sampler_type::TOP_N_SIGMA: llama_sample_top_n_sigma(ctx_main, &cur_p, top_n_sigma); break;
case llama_sampler_type::TEMPERATURE:
if (dynatemp_range > 0) {
float dynatemp_min = std::max(0.0f, temp - dynatemp_range);
diff --git a/common/sampling.h b/common/sampling.h
index 163cdfca..99fb07ac 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -16,6 +16,7 @@ enum class llama_sampler_type : char {
MIN_P = 'm',
TFS_Z = 'f',
XTC = 'x',
+ TOP_N_SIGMA = 'n',
TYPICAL_P = 'y',
TEMPERATURE = 't'
};
@@ -41,7 +42,8 @@ typedef struct llama_sampling_params {
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
float xtc_probability = 0.0f; // xtc probability
- float xtc_threshold = 1.0f; // xtc threashold, disabled if > 0.5
+ float xtc_threshold = 1.0f; // xtc threshold, disabled if > 0.5
+ float top_n_sigma = 0.0f; // top-n-sigma
bool penalize_nl = false; // consider newlines as a repeatable token
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
diff --git a/examples/main/README.md b/examples/main/README.md
index 9396a34f..dd627456 100644
--- a/examples/main/README.md
+++ b/examples/main/README.md
@@ -239,6 +239,22 @@ The `--mirostat-ent` option sets the Mirostat target entropy (tau), which repres
Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0`
+### XTC Sampling (Exclude Top Choices)
+
+The function of this sampler is conrolled by `--xtc-probability` and `--xtc-threshold`. `--xtc-probability` takes values between
+0 and 1 (<=0 turns this sampler off) and defines the probability for randomly invoking the sampler. `--xtc-threshold`
+defines the token probability threshold. Tokens with probability greater than this threshold will be excluded from the sampling.
+The sampler is turned off for `threshold > 0.5`.
+
+- --xtc-probability p: xtc probability (default: 0.0 => disabled)
+- --xtc-threshold t : xtc threshold (default: 1.0 => disabled)
+
+### Top-n-sigma Sampling
+
+Sets all logits $L_i$ to $-\infty$ where $L_i < L_{\rm max} - n \sigma$. Here $L_{\rm max}$ is the maximum logit, $\sigma$ is the logit standard deviation, and $n$ (a floating point number) is the top-n-sigma parameter. Increasing $n$ increases the fraction of tokens considered for sampling. In the limit of $n$ close to zero, one effectively gets greedy sampling (only top probability token considered).
+
+- --top-n-sigma t top-n-sigma parmeter (default: 0.0 => disabled)
+
### Logit Bias
- `-l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS`: Modify the likelihood of a token appearing in the generated text completion.
diff --git a/examples/server/README.md b/examples/server/README.md
index cb1eb7c9..c5a4b3e7 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -98,6 +98,9 @@ sampling:
(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)
--mirostat-lr N Mirostat learning rate, parameter eta (default: 0.1)
--mirostat-ent N Mirostat target entropy, parameter tau (default: 5.0)
+ --xtc-probability p xtc probability (default: 0.0 => disabled)
+ --xtc-threshold t xtc threshold (default: 1.0 => disabled)
+ --top-n-sigma t top-n-sigma parmeter (default: 0.0 => disabled)
-l TOKEN_ID(+/-)BIAS modifies the likelihood of token appearing in the completion,
i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',
or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'
diff --git a/include/llama.h b/include/llama.h
index 89526276..f1645228 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1216,6 +1216,13 @@ extern "C" {
float threshold,
size_t min_keep);
+ /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
+ LLAMA_API void llama_sample_top_n_sigma(
+ struct llama_context * ctx,
+ llama_token_data_array * candidates_p,
+ float top_n_sigma);
+
+
/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 06f44b02..7a185c5b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -435,7 +435,7 @@ void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array
}
void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep) {
- if (probability < 0 || threshold > 0.5f || candidates->size < 2) {
+ if (probability <= 0 || threshold > 0.5f || candidates->size < 2) {
return;
}
GGML_ASSERT(smpl);
@@ -468,6 +468,64 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array
}
+void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) {
+
+ if (top_n_sigma <= 0.0f || candidates->size < 4) {
+ // top_n_sigma <= 0: disabled
+ // candidates->size < 4: no point in applying the transformation for fewer than 4 logits.
+ return;
+ }
+
+ const int64_t t_start_sample_us = ggml_time_us();
+
+ float max = candidates->data[0].logit;
+ float mean = 0;
+ size_t count = 0;
+ for (int i = 0; i < (int)candidates->size; ++i) {
+ // Only count non-negative infinity values
+ if (candidates->data[i].logit != -INFINITY) {
+ max = std::max(max, candidates->data[i].logit);
+ mean += candidates->data[i].logit;
+ ++count;
+ }
+ }
+ if (count < 4) {
+ return; // again, tandard deviation is not well defined for so few logits (4 is actually pushing it)
+ }
+ mean /= count;
+
+ float sigma2 = 0;
+ for (int i = 0; i < (int)candidates->size; ++i) {
+ if (candidates->data[i].logit != -INFINITY) {
+ float delta = candidates->data[i].logit - mean;
+ sigma2 += delta*delta;
+ }
+ }
+ float sigma = sqrtf(sigma2/count);
+ float thresh = max - top_n_sigma*sigma;
+
+ int n_masked = 0;
+ for (int i = 0; i < (int)candidates->size; ++i) {
+ if (candidates->data[i].logit != -INFINITY && candidates->data[i].logit < thresh) {
+ candidates->data[i].logit = -INFINITY;
+ ++n_masked;
+ }
+ }
+
+ // do we really want to compute softmax unconditionally?
+ // The following coresponds to mainline implementation with the minor optimization
+ // that we only call the relativly expensive softmax if we masked away some tokens.
+ if (n_masked > 0 || !candidates->sorted) {
+ llama_sample_softmax_impl(nullptr, candidates);
+ }
+
+ if (smpl) {
+ smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
+ smpl->n_sample++;
+ }
+}
+
+
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
llama_token_data_array * candidates,
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index c2a9e45f..69d92a3a 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -33,6 +33,7 @@ void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_
void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep);
+void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma);
void llama_sample_repetition_penalties_impl(
struct llama_sampling * smpl,
diff --git a/src/llama.cpp b/src/llama.cpp
index 90e342e1..be404500 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -23270,6 +23270,10 @@ void llama_sample_xtc(struct llama_context * ctx, llama_token_data_array * candi
llama_sample_xtc_impl(ctx ? &ctx->sampling : nullptr, candidates_p, probability, threshold, min_keep);
}
+void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array * candidates_p, float top_n_sigma) {
+ llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma);
+}
+
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,