diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-20 21:07:23 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-20 21:07:23 +0300 |
commit | d1031cf49c3b958b915fd558e23453471c29ac33 (patch) | |
tree | 14fa2bc6d54d5e27bd1e8bfd6fa4dbf894dbe6b9 /examples/llava | |
parent | 8cf19d60dc93809db8e51fedc811595eed9134c5 (diff) |
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661
ggml-ci
Diffstat (limited to 'examples/llava')
-rw-r--r-- | examples/llava/llava-utils.h | 58 |
1 file changed, 30 insertions, 28 deletions
diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h index e050b59b..45b2b1ad 100644 --- a/examples/llava/llava-utils.h +++ b/examples/llava/llava-utils.h @@ -58,28 +58,30 @@ inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n // TODO: use common/sampling.h inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) { - // out of user input, sample next token - const float temp = params.sampling_params.temp; - const int32_t top_k = params.sampling_params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : params.sampling_params.top_k; - const float top_p = params.sampling_params.top_p; - const float tfs_z = params.sampling_params.tfs_z; - const float typical_p = params.sampling_params.typical_p; - // const int32_t repeat_last_n = params.sampling_params.repeat_last_n < 0 ? n_ctx : params.sampling_params.repeat_last_n; - // const float repeat_penalty = params.sampling_params.repeat_penalty; - // const float alpha_presence = params.sampling_params.presence_penalty; - // const float alpha_frequency = params.sampling_params.frequency_penalty; - const int mirostat = params.sampling_params.mirostat; - const float mirostat_tau = params.sampling_params.mirostat_tau; - const float mirostat_eta = params.sampling_params.mirostat_eta; - // const bool penalize_nl = params.sampling_params.penalize_nl; + auto & sparams = params.sparams; + + // out of user input, sample next token + const float temp = sparams.temp; + const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx_llama)) : sparams.top_k; + const float top_p = sparams.top_p; + const float tfs_z = sparams.tfs_z; + const float typical_p = sparams.typical_p; + // const int32_t repeat_last_n = sparams.repeat_last_n < 0 ? 
n_ctx : sparams.repeat_last_n; + // const float repeat_penalty = sparams.repeat_penalty; + // const float alpha_presence = sparams.presence_penalty; + // const float alpha_frequency = sparams.frequency_penalty; + const int mirostat = sparams.mirostat; + const float mirostat_tau = sparams.mirostat_tau; + const float mirostat_eta = sparams.mirostat_eta; + // const bool penalize_nl = sparams.penalize_nl; llama_token id = 0; { auto logits = llama_get_logits(ctx_llama); auto n_vocab = llama_n_vocab(llama_get_model(ctx_llama)); - // Apply params.logit_bias map - for (auto it = params.sampling_params.logit_bias.begin(); it != params.sampling_params.logit_bias.end(); it++) { + // Apply params.logit_bias map + for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) { logits[it->first] += it->second; } @@ -91,18 +93,18 @@ inline llama_token sample_id(llama_context * ctx_llama, gpt_params & params) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - // TODO: Apply penalties - // float nl_logit = logits[llama_token_nl(ctx)]; - // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); - // llama_sample_repetition_penalty(ctx, &candidates_p, - // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - // last_n_repeat, repeat_penalty); - // llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, - // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, - // last_n_repeat, alpha_frequency, alpha_presence); - // if (!penalize_nl) { - // logits[llama_token_nl(ctx)] = nl_logit; - // } + // TODO: Apply penalties + // float nl_logit = logits[llama_token_nl(ctx)]; + // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); + // llama_sample_repetition_penalty(ctx, &candidates_p, + // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + // last_n_repeat, repeat_penalty); + // 
llama_sample_frequency_and_presence_penalties(ctx, &candidates_p, + // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, + // last_n_repeat, alpha_frequency, alpha_presence); + // if (!penalize_nl) { + // logits[llama_token_nl(ctx)] = nl_logit; + // } if (temp <= 0) { // Greedy sampling |