summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp1
-rw-r--r--common/sampling.cpp5
-rw-r--r--common/sampling.h1
3 files changed, 6 insertions, 1 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 489462b5..10ef1182 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1704,6 +1704,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
}
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+ fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 611c327b..de4331a1 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -248,7 +248,10 @@ static llama_token llama_sampling_sample_impl(
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
- sampler_queue(ctx_main, params, cur_p, 1);
+ // temperature sampling
+ size_t min_keep = std::max(1, params.min_keep);
+
+ sampler_queue(ctx_main, params, cur_p, min_keep);
id = llama_sample_token(ctx_main, &cur_p);
diff --git a/common/sampling.h b/common/sampling.h
index e1279a89..95d87539 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -22,6 +22,7 @@ enum class llama_sampler_type : char {
typedef struct llama_sampling_params {
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+ int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float min_p = 0.05f; // 0.0 = disabled