From d1031cf49c3b958b915fd558e23453471c29ac33 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 20 Oct 2023 21:07:23 +0300
Subject: sampling : refactor init to use llama_sampling_params (#3696)

* sampling : refactor init to use llama_sampling_params

* llama : combine repetition, frequency and presence penalties in 1 call

* examples : remove embd-input and gptneox-wip

* sampling : rename penalty params + reduce size of "prev" vector

* sampling : add llama_sampling_print helper

* sampling : hide prev behind API and apply #3661

ggml-ci
---
 examples/infill/infill.cpp | 67 ++++++++++++----------------------------------
 1 file changed, 17 insertions(+), 50 deletions(-)

diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 128d6708..6331335e 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -39,8 +39,8 @@ static gpt_params * g_params;
 static std::vector<llama_token> * g_input_tokens;
 static std::ostringstream       * g_output_ss;
 static std::vector<llama_token> * g_output_tokens;
-static bool is_interacting = false;
+static bool is_interacting = false;
 
 static void write_logfile(
     const llama_context * ctx, const gpt_params & params, const llama_model * model,
@@ -104,7 +104,7 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
-    llama_sampling_params & sparams = params.sampling_params;
+    llama_sampling_params & sparams = params.sparams;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -358,36 +358,10 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
-    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+    LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
-    struct llama_grammar * grammar = NULL;
-    grammar_parser::parse_state parsed_grammar;
-
-    if (!params.grammar.empty()) {
-        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-        // will be empty (default) if there are parse errors
-        if (parsed_grammar.rules.empty()) {
-            return 1;
-        }
-        LOG_TEE("%s: grammar:\n", __func__);
-        grammar_parser::print_grammar(stderr, parsed_grammar);
-        LOG_TEE("\n");
-
-        {
-            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
-            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
-                LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
-            }
-        }
-
-        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    }
-
     LOG_TEE("\n##### Infill mode #####\n\n");
     if (params.infill) {
         printf("\n************\n");
@@ -430,7 +404,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd;
     std::vector<llama_token> embd_guidance;
 
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
     while (n_remain != 0 || params.interactive) {
         // predict
@@ -549,7 +523,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
 
-            llama_sampling_accept(ctx_sampling, ctx, id);
+            llama_sampling_accept(ctx_sampling, ctx, id, true);
 
             LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
 
@@ -567,8 +541,11 @@ int main(int argc, char ** argv) {
             LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
             while ((int) embd_inp.size() > n_consumed) {
                 embd.push_back(embd_inp[n_consumed]);
-                ctx_sampling->prev.erase(ctx_sampling->prev.begin());
-                ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
+                // push the prompt in the sampling context in order to apply repetition penalties later
+                // for the prompt, we don't apply grammar rules
+                llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
+
                 ++n_consumed;
                 if ((int) embd.size() >= params.n_batch) {
                     break;
@@ -600,7 +577,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed) {
             // deal with eot token in infill mode
-            if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+            if ((llama_sampling_last(ctx_sampling) == llama_token_eot(ctx) || is_interacting) && params.interactive){
                 if(is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -617,7 +594,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line, if so we use the old input
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_prefix = buffer;
                 }
                 buffer.clear();
@@ -627,7 +604,7 @@ int main(int argc, char ** argv) {
                     buffer += line;
                 } while (another_line);
                 // check if we got an empty line
-                if(!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
+                if (!buffer.empty() && !(buffer.length() == 1 && buffer[0] == '\n')) {
                     params.input_suffix = buffer;
                 }
                 buffer.clear();
@@ -640,7 +617,7 @@ int main(int argc, char ** argv) {
                     process_escapes(params.input_suffix);
                 }
                 suff_rm_leading_spc = params.escape;
-                if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) {
+                if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                     params.input_suffix.erase(0, 1);
                     suff_rm_leading_spc = false;
                 }
@@ -667,7 +644,7 @@ int main(int argc, char ** argv) {
                 is_interacting = false;
             }
             // deal with end of text token in interactive mode
-            else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
+            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -740,15 +717,7 @@ int main(int argc, char ** argv) {
 
         if (n_past > 0) {
            if (is_interacting) {
-                // reset grammar state if we're restarting generation
-                if (grammar != NULL) {
-                    llama_grammar_free(grammar);
-
-                    std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-                    grammar = llama_grammar_init(
-                        grammar_rules.data(), grammar_rules.size(),
-                        parsed_grammar.symbol_ids.at("root"));
-                }
+                llama_sampling_reset(ctx_sampling);
             }
             is_interacting = false;
         }
@@ -778,9 +747,7 @@ int main(int argc, char ** argv) {
     llama_free(ctx);
     llama_free_model(model);
 
-    if (grammar != NULL) {
-        llama_grammar_free(grammar);
-    }
+    llama_sampling_free(ctx_sampling);
     llama_backend_free();
 
 #ifndef LOG_DISABLE_LOGS
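
Taken together, the calls this patch introduces give infill.cpp the following sampling-context lifecycle. This is a minimal sketch assembled only from the calls visible in the diff above; the headers, model/context setup, decode loop and error handling are assumptions and are elided, and passing NULL for the guidance context is assumed to be valid when classifier-free guidance is off.

    // sketch only, not part of the patch
    #include <vector>
    #include "common.h"
    #include "sampling.h"

    static void sampling_lifecycle(llama_context * ctx, const gpt_params & params,
                                   const std::vector<llama_token> & embd_inp) {
        const llama_sampling_params & sparams = params.sparams;

        // one init call replaces the per-example grammar + "prev" bookkeeping
        struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);

        // prompt tokens: recorded so repetition penalties can see them later,
        // but grammar rules are not applied (last argument = false)
        for (const llama_token tok : embd_inp) {
            llama_sampling_accept(ctx_sampling, ctx, tok, false);
        }

        // generated tokens: sample, then accept with grammar applied (true)
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);
        llama_sampling_accept(ctx_sampling, ctx, id, true);

        // the "prev" history is now read through the API, not the raw vector
        if (llama_sampling_last(ctx_sampling) == llama_token_eos(ctx)) {
            // restarting generation resets grammar state and history in one call
            llama_sampling_reset(ctx_sampling);
        }

        llama_sampling_free(ctx_sampling);
    }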
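The "combine repetition, frequency and presence penalties in 1 call" bullet lands in llama.h rather than in this example, so the merged call is not visible in the diff above. Roughly, it should look like the sketch below; the function name, its parameter order and the renamed llama_sampling_params fields (penalty_last_n, penalty_repeat, penalty_freq, penalty_present) are assumptions based on the commit description, not taken from this diff.

    // assumed shape of the merged penalty entry point, sketch only
    #include <algorithm>
    #include <cstddef>
    #include <vector>
    #include "llama.h"
    #include "sampling.h"

    static void apply_penalties(llama_context * ctx,
                                llama_token_data_array * candidates,
                                const std::vector<llama_token> & prev,
                                const llama_sampling_params & sparams) {
        // assumes penalty_last_n >= 0; the real code may treat -1 as "whole context"
        const size_t n = std::min(prev.size(), (size_t) sparams.penalty_last_n);

        // one call now applies the multiplicative repetition penalty and the
        // additive frequency/presence penalties over the last n tokens
        llama_sample_repetition_penalties(
                ctx, candidates,
                prev.data() + (prev.size() - n), n,
                sparams.penalty_repeat,
                sparams.penalty_freq,
                sparams.penalty_present);
    }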