Diffstat (limited to 'examples/infill/infill.cpp')
-rw-r--r-- | examples/infill/infill.cpp | 134
1 file changed, 9 insertions, 125 deletions
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 539f7818..0e4ec79c 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -107,6 +107,7 @@ int main(int argc, char ** argv) {
     g_params = &params;
 
     if (!gpt_params_parse(argc, argv, params)) {
+        gpt_params_print_usage(argc, argv, params);
         return 1;
     }
 
@@ -139,27 +140,6 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
     }
-    if (params.instruct) {
-        printf("\n************\n");
-        printf("%s: please use the 'main' tool for instruct mode\n", __func__);
-        printf("************\n\n");
-
-        return 0;
-    }
-    if (params.chatml) {
-        printf("\n************\n");
-        printf("%s: please use the 'main' tool for chatml mode\n", __func__);
-        printf("************\n\n");
-
-        return 0;
-    }
-    if (!params.antiprompt.empty()) {
-        printf("\n************\n");
-        printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);
-        printf("************\n\n");
-
-        return 0;
-    }
     if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) {
         printf("\n************\n");
         printf("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__);
@@ -167,20 +147,6 @@ int main(int argc, char ** argv) {
 
         return 0;
     }
-    if (params.random_prompt) {
-        printf("\n************\n");
-        printf("%s: please use the 'main' tool for random prompt mode\n", __func__);
-        printf("************\n\n");
-
-        return 0;
-    }
-    if (!params.path_prompt_cache.empty()) {
-        printf("\n************\n");
-        printf("%s: infill does not support prompt caching\n", __func__);
-        printf("************\n\n");
-
-        return 0;
-    }
 
     if (params.rope_freq_base != 0.0) {
         LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
@@ -207,17 +173,13 @@ int main(int argc, char ** argv) {
 
     llama_model * model;
     llama_context * ctx;
-    llama_context * ctx_guidance = NULL;
+
     g_model = &model;
     g_ctx = &ctx;
 
     // load the model and apply lora adapter, if any
     LOG("%s: load the model and apply lora adapter, if any\n", __func__);
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
-    if (sparams.cfg_scale > 1.f) {
-        struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
-        ctx_guidance = llama_new_context_with_model(model, lparams);
-    }
 
     if (model == NULL) {
         LOG_TEE("%s: error: unable to load model\n", __func__);
@@ -273,25 +235,6 @@ int main(int argc, char ** argv) {
         LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
     }
 
-    // Tokenize negative prompt
-    std::vector<llama_token> guidance_inp;
-    int guidance_offset = 0;
-    int original_prompt_len = 0;
-    if (ctx_guidance) {
-        LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
-        guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
-        LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
-        LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
-        original_prompt_len = original_inp.size();
-        guidance_offset = (int)guidance_inp.size() - original_prompt_len;
-        LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
-        LOG("guidance_offset: %s", log_tostr(guidance_offset));
-    }
-
     if ((int) embd_inp.size() > n_ctx - 4) {
         LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
         return 1;
@@ -319,15 +262,6 @@ int main(int argc, char ** argv) {
             LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
         }
 
-        if (ctx_guidance) {
-            LOG_TEE("\n");
-            LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
-            LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
-            for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
-            }
-        }
-
         if (params.n_keep > 0) {
             LOG_TEE("%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
@@ -395,12 +329,11 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
 
-    bool input_echo = true;
+    bool input_echo = true;
 
-    int n_past          = 0;
-    int n_remain        = params.n_predict;
-    int n_consumed      = 0;
-    int n_past_guidance = 0;
+    int n_past     = 0;
+    int n_remain   = params.n_predict;
+    int n_consumed = 0;
 
     std::vector<int> input_tokens;  g_input_tokens  = &input_tokens;
     std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -410,7 +343,6 @@ int main(int argc, char ** argv) {
     console::set_display(console::prompt);
 
     std::vector<llama_token> embd;
-    std::vector<llama_token> embd_guidance;
 
     struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
 
@@ -436,7 +368,7 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
            // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
-            if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+            if (n_past + (int) embd.size() > n_ctx) {
                 if (params.n_predict == -2) {
                     LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
                     break;
@@ -453,11 +385,7 @@ int main(int argc, char ** argv) {
 
             n_past -= n_discard;
 
-            if (ctx_guidance) {
-                n_past_guidance -= n_discard;
-            }
-
-            LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+            LOG("after swap: n_past = %d\n", n_past);
 
             LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
@@ -465,45 +393,6 @@ int main(int argc, char ** argv) {
 
             // evaluate tokens in batches
             // embd is typically prepared beforehand to fit within a batch, but not always
-
-            if (ctx_guidance) {
-                int input_size = 0;
-                llama_token * input_buf = NULL;
-
-                if (n_past_guidance < (int) guidance_inp.size()) {
-                    // Guidance context should have the same data with these modifications:
-                    //
-                    // * Replace the initial prompt
-                    // * Shift everything by guidance_offset
-                    embd_guidance = guidance_inp;
-                    if (embd.begin() + original_prompt_len < embd.end()) {
-                        embd_guidance.insert(
-                            embd_guidance.end(),
-                            embd.begin() + original_prompt_len,
-                            embd.end()
-                        );
-                    }
-
-                    input_buf = embd_guidance.data();
-                    input_size = embd_guidance.size();
-
-                    LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
-                } else {
-                    input_buf = embd.data();
-                    input_size = embd.size();
-                }
-
-                for (int i = 0; i < input_size; i += params.n_batch) {
-                    int n_eval = std::min(input_size - i, params.n_batch);
-                    if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
-                        LOG_TEE("%s : failed to eval\n", __func__);
-                        return 1;
-                    }
-
-                    n_past_guidance += n_eval;
-                }
-            }
-
             for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
                 int n_eval = (int) embd.size() - i;
                 if (n_eval > params.n_batch) {
@@ -525,11 +414,9 @@ int main(int argc, char ** argv) {
         }
 
         embd.clear();
-        embd_guidance.clear();
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-
-            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+            const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
 
             llama_sampling_accept(ctx_sampling, ctx, id, true);
 
@@ -583,7 +470,6 @@ int main(int argc, char ** argv) {
 
         // if not currently processing queued inputs;
         if ((int) embd_inp.size() <= n_consumed) {
-
             // deal with eot token in infill mode
             if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
                 if (is_interacting && !params.interactive_first) {
@@ -644,7 +530,6 @@ int main(int argc, char ** argv) {
                 embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
                 embd_inp.push_back(llama_token_middle(model));
                 embd.clear();
-                embd_guidance.clear();
                 n_remain = params.n_predict;
                 n_past = 0;
                 n_consumed = 0;
@@ -751,7 +636,6 @@ int main(int argc, char ** argv) {
     llama_print_timings(ctx);
     write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
 
-    if (ctx_guidance) { llama_free(ctx_guidance); }
     llama_free(ctx);
     llama_free_model(model);
 