author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-20 21:07:23 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2023-10-20 21:07:23 +0300
commit | d1031cf49c3b958b915fd558e23453471c29ac33 (patch) |
tree | 14fa2bc6d54d5e27bd1e8bfd6fa4dbf894dbe6b9 /examples/main/main.cpp |
parent | 8cf19d60dc93809db8e51fedc811595eed9134c5 (diff) |
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661
ggml-ci
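
For context, the fragment below is a minimal sketch of the sampling flow in examples/main/main.cpp after this refactor. It uses only calls that appear in the diff further down (llama_sampling_print, llama_sampling_init, llama_sampling_sample, llama_sampling_accept); the wrapper function, the include paths and the final llama_sampling_free cleanup are illustrative assumptions, not part of this commit.

// Sketch only: init/print/sample/accept after this refactor.
// The wrapper, includes and the final free are assumed for illustration;
// the llama_sampling_* calls are the ones used in the diff below.
#include <cstdio>

#include "common.h"    // gpt_params (assumed include path, as in the examples)
#include "sampling.h"  // llama_sampling_* helpers (common/sampling.h at this commit)

static llama_token sampling_sketch(gpt_params & params, llama_context * ctx, llama_context * ctx_guidance) {
    // sampling parameters now live in params.sparams (renamed from params.sampling_params)
    llama_sampling_params & sparams = params.sparams;

    // one helper formats all sampling parameters, replacing the long hand-written format string
    printf("sampling: \n%s\n", llama_sampling_print(sparams).c_str());

    // the sampling context is created from llama_sampling_params instead of the whole gpt_params
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

    // generated tokens: sample, then accept with grammar applied (last argument = true)
    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
    llama_sampling_accept(ctx_sampling, ctx, id, true);

    // prompt tokens are accepted too, so repetition penalties see them, but without grammar:
    // llama_sampling_accept(ctx_sampling, ctx, prompt_token, false);

    llama_sampling_free(ctx_sampling);  // assumed cleanup, mirroring the rest of main.cpp
    return id;
}
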
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 28 |
1 file changed, 11 insertions, 17 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 1a5911c5..db5309af 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -108,7 +108,7 @@ int main(int argc, char ** argv) {
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
-    llama_sampling_params & sparams = params.sampling_params;
+    llama_sampling_params & sparams = params.sparams;
 
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("main", "log"));
@@ -415,8 +415,7 @@ int main(int argc, char ** argv) {
             }
         }
     }
-    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -459,7 +458,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
@@ -612,7 +611,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -631,12 +630,9 @@ int main(int argc, char ** argv) {
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
 
-                // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
-                //     Most likely will remove this in the future to avoid exposing "prev"
-                //     Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
-                //     penalty will be applied only based on the tokens generated by the model.
-                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
 
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
@@ -667,12 +663,10 @@ int main(int argc, char ** argv) {
 
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-            // check for reverse prompt
+            // check for reverse prompt in the last n_prev tokens
            if (!params.antiprompt.empty()) {
-                std::string last_output;
-                for (auto id : ctx_sampling->prev) {
-                    last_output += llama_token_to_piece(ctx, id);
-                }
+                const int n_prev = 32;
+                const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
 
                 is_antiprompt = false;
                 // Check if each of the reverse prompts appears at the end of the output.
@@ -699,7 +693,7 @@ int main(int argc, char ** argv) {
             }
 
             // deal with end of text token in interactive mode
-            if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
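
The reverse-prompt and end-of-text checks in the hunks above now go through the sampling API instead of touching ctx_sampling->prev directly. The sketch below shows that pattern in isolation: llama_sampling_prev_str, llama_sampling_last and n_prev = 32 are taken from the diff, while the wrapper function and the suffix comparison are illustrative assumptions rather than the exact logic in main.cpp.

// Sketch only: querying the (now hidden) "prev" token history through the API.
// llama_sampling_prev_str / llama_sampling_last come from the diff above; the suffix
// check and the wrapper are illustrative, not the exact logic in main.cpp.
#include <string>
#include <vector>

#include "llama.h"     // llama_context, llama_token_eos
#include "sampling.h"  // llama_sampling_* helpers (common/sampling.h at this commit)

static bool ends_with(const std::string & text, const std::string & suffix) {
    return text.size() >= suffix.size() &&
           text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
}

static bool should_stop(llama_sampling_context * ctx_sampling,
                        llama_context * ctx,
                        const std::vector<std::string> & antiprompt) {
    // detokenize the last n_prev tokens of output via the sampling context (n_prev = 32 as in the diff)
    const int n_prev = 32;
    const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);

    // reverse prompt: stop when any antiprompt string appears at the end of the recent output
    for (const std::string & ap : antiprompt) {
        if (ends_with(last_output, ap)) {
            return true;
        }
    }

    // end of text: the last sampled token is read through the API instead of ctx_sampling->prev.back()
    return llama_sampling_last(ctx_sampling) == llama_token_eos(ctx);
}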