Diffstat (limited to 'examples/infill')
-rw-r--r-- | examples/infill/infill.cpp | 18 | ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index d994de5e..187623f5 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -104,6 +104,7 @@ static void sigint_handler(int signo) {
 
 int main(int argc, char ** argv) {
     gpt_params params;
+    llama_sampling_params & sparams = params.sampling_params;
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
@@ -206,7 +207,7 @@ int main(int argc, char ** argv) {
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    if (params.cfg_scale > 1.f) {
+    if (sparams.cfg_scale > 1.f) {
         struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
         ctx_guidance = llama_new_context_with_model(model, lparams);
     }
@@ -269,9 +270,9 @@ int main(int argc, char ** argv) {
     int guidance_offset = 0;
     int original_prompt_len = 0;
     if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
         std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -312,7 +313,7 @@ int main(int argc, char ** argv) {
 
     if (ctx_guidance) {
         LOG_TEE("\n");
-        LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+        LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
         LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
         for (int i = 0; i < (int) guidance_inp.size(); i++) {
             LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -358,7 +359,7 @@ int main(int argc, char ** argv) {
         }
     }
     LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
 
@@ -376,8 +377,8 @@ int main(int argc, char ** argv) {
         LOG_TEE("\n");
 
         {
-            auto it = params.logit_bias.find(llama_token_eos(ctx));
-            if (it != params.logit_bias.end() && it->second == -INFINITY) {
+            auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+            if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
                 LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
             }
         }
@@ -434,6 +435,7 @@ int main(int argc, char ** argv) {
 
     const int n_vocab = llama_n_vocab(model);
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
 
@@ -552,7 +554,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
 
-            const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+            const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
 
             last_tokens.erase(last_tokens.begin());
             last_tokens.push_back(id);
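
The patch routes everything the sampler needs through the new llama_sampling_context instead of passing gpt_params and the grammar separately at each call site. A minimal sketch of the resulting call pattern, using only the names that appear in this diff; the surrounding generation loop, the n_remain counter, and the trailing ellipsis comments are illustrative assumptions, not code from the patch:

    gpt_params params;
    llama_sampling_params & sparams = params.sampling_params;  // alias, as in the patch

    // one sampling context per generation, seeded from the shared params
    // and the (optional) grammar, as introduced in the diff above
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);

    std::vector<llama_token_data> candidates;
    candidates.reserve(llama_n_vocab(model));

    while (n_remain > 0) {  // illustrative loop condition
        // the sampling knobs (top_k, temp, mirostat, ...) and the grammar are
        // read through ctx_sampling, so they no longer travel as loose arguments
        const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);

        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);
        // ... decode id, emit the piece, --n_remain ...
    }

Bundling the parameters and grammar behind one context object also gives mutable per-sequence sampler state (for example mirostat's running coefficients) a single owner, which appears to be the point of replacing the grammar/params argument pair.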