Diffstat (limited to 'common')
-rw-r--r--  common/common.cpp     7
-rw-r--r--  common/common.h       3
-rw-r--r--  common/sampling.cpp  79
-rw-r--r--  common/sampling.h     7
4 files changed, 87 insertions, 9 deletions
diff --git a/common/common.cpp b/common/common.cpp
index dbe7e922..036a9813 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -513,12 +513,6 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_sequences = std::stoi(argv[i]);
- } else if (arg == "--p-accept" || arg == "-pa") {
- if (++i >= argc) {
- invalid_param = true;
- break;
- }
- params.p_accept = std::stof(argv[i]);
} else if (arg == "--p-split" || arg == "-ps") {
if (++i >= argc) {
invalid_param = true;
@@ -1044,7 +1038,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
- printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
diff --git a/common/common.h b/common/common.h
index b2868833..977ce419 100644
--- a/common/common.h
+++ b/common/common.h
@@ -53,11 +53,10 @@ struct gpt_params {
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_draft = 8; // number of tokens to draft during speculative decoding
+ int32_t n_draft = 5; // number of tokens to draft during speculative decoding
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
- float p_accept = 0.5f; // speculative decoding accept probability
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
diff --git a/common/sampling.cpp b/common/sampling.cpp
index e67096be..823031fe 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -295,6 +295,77 @@ static llama_token llama_sampling_sample_impl(
return id;
}
+static llama_token_data_array llama_sample_probability_distribution_impl(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ const llama_sampling_params & params = ctx_sampling->params;
+
+ const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+ const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+ const float penalty_repeat = params.penalty_repeat;
+ const float penalty_freq = params.penalty_freq;
+ const float penalty_present = params.penalty_present;
+ const bool penalize_nl = params.penalize_nl;
+
+ auto & prev = ctx_sampling->prev;
+ auto & cur = ctx_sampling->cur;
+
+ // Get a pointer to the logits
+ float * logits = llama_get_logits_ith(ctx_main, idx);
+
+ // Declare original_logits at the beginning of the function scope
+ std::vector<float> original_logits;
+
+ // apply params.logit_bias map
+ for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ logits[it->first] += it->second;
+ }
+
+ if (ctx_cfg) {
+ float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+ llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
+ }
+
+ cur.clear();
+
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+
+ llama_token_data_array cur_p = { cur.data(), cur.size(), false };
+
+ // apply penalties
+ const auto& penalty_tokens = params.use_penalty_prompt_tokens ? params.penalty_prompt_tokens : prev;
+ const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n);
+ if (penalty_tokens_used_size) {
+ const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))];
+
+ llama_sample_repetition_penalties(ctx_main, &cur_p,
+ penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size,
+ penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present);
+
+ if (!penalize_nl) {
+ for (size_t idx = 0; idx < cur_p.size; idx++) {
+ if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) {
+ cur_p.data[idx].logit = nl_logit;
+ break;
+ }
+ }
+ }
+ }
+
+ // apply grammar checks
+ if (ctx_sampling->grammar != NULL) {
+ llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+ }
+
+ llama_sample_softmax(ctx_main, &cur_p);
+ return cur_p;
+}
+
llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
@@ -304,6 +375,14 @@ llama_token llama_sampling_sample(
return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, false);
}
+llama_token_data_array llama_sampling_probability_distribution(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ const int idx) {
+ return llama_sample_probability_distribution_impl(ctx_sampling, ctx_main, ctx_cfg, idx);
+}
+
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
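For reference, a minimal caller-side sketch of the new helper (not part of this diff; the setup of ctx_sampling and ctx_main, and the token_probability wrapper itself, are hypothetical and assumed to follow the existing common examples). llama_sampling_probability_distribution returns the softmaxed distribution for one logits index, so a caller can look up the probability of any token id:

    #include "common/sampling.h"

    // Hypothetical helper: returns the probability assigned to token `id` at batch
    // position `idx`. Pass nullptr for ctx_cfg when no guidance context is used.
    static float token_probability(llama_sampling_context * ctx_sampling,
                                   llama_context          * ctx_main,
                                   llama_token              id,
                                   int                      idx) {
        llama_token_data_array dist =
            llama_sampling_probability_distribution(ctx_sampling, ctx_main, nullptr, idx);
        for (size_t i = 0; i < dist.size; ++i) {
            if (dist.data[i].id == id) {
                return dist.data[i].p; // filled in by llama_sample_softmax
            }
        }
        return 0.0f;
    }
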
diff --git a/common/sampling.h b/common/sampling.h
index 95d87539..48b2459d 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -131,6 +131,13 @@ llama_token llama_sampling_sample(
struct llama_context * ctx_cfg,
int idx = 0);
+// returns the probability distribution over the whole vocabulary for the given logits index (after penalties, grammar and softmax)
+llama_token_data_array llama_sampling_probability_distribution(
+ struct llama_sampling_context * ctx_sampling,
+ struct llama_context * ctx_main,
+ struct llama_context * ctx_cfg,
+ int idx = 0);
+
void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
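
The --p-accept option removed above gated speculative decoding on a fixed probability threshold. With full distributions from both the target and the draft model exposed through llama_sampling_probability_distribution, acceptance can instead follow the standard stochastic speculative-sampling rule: keep a drafted token with probability min(1, p_target / p_draft). A sketch of that rule, for illustration only (not code from this diff; p_target and p_draft are assumed to come from the two models' distributions):

    #include <algorithm>
    #include <random>

    // Keep a draft token with probability min(1, p_target / p_draft).
    static bool accept_draft_token(float p_target, float p_draft, std::mt19937 & rng) {
        if (p_draft <= 0.0f) {
            return false; // the draft model assigned no mass to this token
        }
        std::uniform_real_distribution<float> unif(0.0f, 1.0f);
        return unif(rng) < std::min(1.0f, p_target / p_draft);
    }

On rejection, the standard rule resamples from the normalized positive difference of the two distributions, max(0, p_target - p_draft); that step is outside the scope of this sketch.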