author    Georgi Gerganov <ggerganov@gmail.com>  2023-10-20 21:07:23 +0300
committer GitHub <noreply@github.com>  2023-10-20 21:07:23 +0300
commit    d1031cf49c3b958b915fd558e23453471c29ac33 (patch)
tree      14fa2bc6d54d5e27bd1e8bfd6fa4dbf894dbe6b9 /examples/main/main.cpp
parent    8cf19d60dc93809db8e51fedc811595eed9134c5 (diff)
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661

ggml-ci
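
For orientation, a minimal sketch of the refactored sampling API as it is used in the diff below. The calls, their arguments, and params.sparams are taken from the hunks in this change; the included headers, the llama_sampling_free cleanup call, and the surrounding function are assumptions for illustration only.

#include <cstdio>
#include <vector>

#include "common.h"    // assumed header locations for gpt_params and the sampling helpers
#include "sampling.h"

// Sketch only - shows the shape of the new API, not the full example program.
static void sampling_sketch(gpt_params & params, llama_context * ctx, llama_context * ctx_guidance,
                            const std::vector<llama_token> & embd_inp) {
    // the sampling parameters now live in params.sparams
    llama_sampling_params & sparams = params.sparams;

    // one helper prints all sampling settings (replaces the long hand-written LOG_TEE format string)
    printf("sampling: \n%s\n", llama_sampling_print(sparams).c_str());

    // the sampling context is initialized from the sampling params only
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // push the prompt tokens so repetition penalties can see them; grammar is not applied (last arg = false)
    for (const llama_token tok : embd_inp) {
        llama_sampling_accept(ctx_sampling, ctx, tok, false);
    }

    // ... llama_decode / batch evaluation omitted ...

    // sample a token and record it, this time applying the grammar (last arg = true)
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
    llama_sampling_accept(ctx_sampling, ctx, id, true);

    llama_sampling_free(ctx_sampling); // assumed cleanup counterpart of llama_sampling_init
}
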
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp  28
1 file changed, 11 insertions, 17 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1a5911c5..db5309af 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
- llama_sampling_params & sparams = params.sampling_params;
+ llama_sampling_params & sparams = params.sparams;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log"));
@@ -415,8 +415,7 @@ int main(int argc, char ** argv) {
}
}
}
- LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
- sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+ LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
@@ -459,7 +458,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@@ -612,7 +611,7 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- llama_sampling_accept(ctx_sampling, ctx, id);
+ llama_sampling_accept(ctx_sampling, ctx, id, true);
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
@@ -631,12 +630,9 @@ int main(int argc, char ** argv) {
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
- // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
- // Most likely will remove this in the future to avoid exposing "prev"
- // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
- // penalty will be applied only based on the tokens generated by the model.
- ctx_sampling->prev.erase(ctx_sampling->prev.begin());
- ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+ // push the prompt in the sampling context in order to apply repetition penalties later
+ // for the prompt, we don't apply grammar rules
+ llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
@@ -667,12 +663,10 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
- // check for reverse prompt
+ // check for reverse prompt in the last n_prev tokens
if (!params.antiprompt.empty()) {
- std::string last_output;
- for (auto id : ctx_sampling->prev) {
- last_output += llama_token_to_piece(ctx, id);
- }
+ const int n_prev = 32;
+ const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
is_antiprompt = false;
// Check if each of the reverse prompts appears at the end of the output.
@@ -699,7 +693,7 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
- if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+ if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
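
As a reference, a minimal sketch of how the reverse-prompt and end-of-text checks read with the new helpers introduced in this change. The helper calls and the n_prev value are taken from the hunks above; the std::vector<std::string> type of params.antiprompt and the exact suffix-matching logic are assumptions.

// Sketch only: reverse-prompt check against the last n_prev tokens via the sampling API.
const int n_prev = 32;
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);

bool is_antiprompt = false;
for (const std::string & antiprompt : params.antiprompt) {
    // accept a match only if the reverse prompt appears at the very end of the recent output
    if (last_output.size() >= antiprompt.size() &&
        last_output.compare(last_output.size() - antiprompt.size(), antiprompt.size(), antiprompt) == 0) {
        is_antiprompt = true;
        break;
    }
}

// end-of-text handling goes through the API instead of reading ctx_sampling->prev directly
if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
    LOG("found EOS token\n");
}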