summaryrefslogtreecommitdiff
path: root/examples/infill/infill.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/infill/infill.cpp')
-rw-r--r--examples/infill/infill.cpp134
1 files changed, 9 insertions, 125 deletions
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 539f7818..0e4ec79c 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -107,6 +107,7 @@ int main(int argc, char ** argv) {
g_params = &params;
if (!gpt_params_parse(argc, argv, params)) {
+ gpt_params_print_usage(argc, argv, params);
return 1;
}
@@ -139,27 +140,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
params.n_ctx = 8;
}
- if (params.instruct) {
- printf("\n************\n");
- printf("%s: please use the 'main' tool for instruct mode\n", __func__);
- printf("************\n\n");
-
- return 0;
- }
- if (params.chatml) {
- printf("\n************\n");
- printf("%s: please use the 'main' tool for chatml mode\n", __func__);
- printf("************\n\n");
-
- return 0;
- }
- if (!params.antiprompt.empty()) {
- printf("\n************\n");
- printf("%s: please use the 'main' tool for antiprompt mode\n", __func__);
- printf("************\n\n");
-
- return 0;
- }
if (!params.interactive_first && (params.input_prefix.empty() && params.input_suffix.empty())) {
printf("\n************\n");
printf("%s: please use '--interactive_first' or specify '--in_prefix' and/or '--in_suffix'\n", __func__);
@@ -167,20 +147,6 @@ int main(int argc, char ** argv) {
return 0;
}
- if (params.random_prompt) {
- printf("\n************\n");
- printf("%s: please use the 'main' tool for random prompt mode\n", __func__);
- printf("************\n\n");
-
- return 0;
- }
- if (!params.path_prompt_cache.empty()) {
- printf("\n************\n");
- printf("%s: infill does not support prompt caching\n", __func__);
- printf("************\n\n");
-
- return 0;
- }
if (params.rope_freq_base != 0.0) {
LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base);
@@ -207,17 +173,13 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;
- llama_context * ctx_guidance = NULL;
+
g_model = &model;
g_ctx = &ctx;
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- if (sparams.cfg_scale > 1.f) {
- struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
- ctx_guidance = llama_new_context_with_model(model, lparams);
- }
if (model == NULL) {
LOG_TEE("%s: error: unable to load model\n", __func__);
@@ -273,25 +235,6 @@ int main(int argc, char ** argv) {
LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
- // Tokenize negative prompt
- std::vector<llama_token> guidance_inp;
- int guidance_offset = 0;
- int original_prompt_len = 0;
- if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
-
- guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
-
- std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
-
- original_prompt_len = original_inp.size();
- guidance_offset = (int)guidance_inp.size() - original_prompt_len;
- LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
- LOG("guidance_offset: %s", log_tostr(guidance_offset));
- }
-
if ((int) embd_inp.size() > n_ctx - 4) {
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
return 1;
@@ -319,15 +262,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
- if (ctx_guidance) {
- LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
- LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
- for (int i = 0; i < (int) guidance_inp.size(); i++) {
- LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
- }
- }
-
if (params.n_keep > 0) {
LOG_TEE("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
@@ -395,12 +329,11 @@ int main(int argc, char ** argv) {
is_interacting = params.interactive_first;
}
- bool input_echo = true;
+ bool input_echo = true;
- int n_past = 0;
- int n_remain = params.n_predict;
- int n_consumed = 0;
- int n_past_guidance = 0;
+ int n_past = 0;
+ int n_remain = params.n_predict;
+ int n_consumed = 0;
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
@@ -410,7 +343,6 @@ int main(int argc, char ** argv) {
console::set_display(console::prompt);
std::vector<llama_token> embd;
- std::vector<llama_token> embd_guidance;
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
@@ -436,7 +368,7 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
- if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
+ if (n_past + (int) embd.size() > n_ctx) {
if (params.n_predict == -2) {
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
@@ -453,11 +385,7 @@ int main(int argc, char ** argv) {
n_past -= n_discard;
- if (ctx_guidance) {
- n_past_guidance -= n_discard;
- }
-
- LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+ LOG("after swap: n_past = %d\n", n_past);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
@@ -465,45 +393,6 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
-
- if (ctx_guidance) {
- int input_size = 0;
- llama_token * input_buf = NULL;
-
- if (n_past_guidance < (int) guidance_inp.size()) {
- // Guidance context should have the same data with these modifications:
- //
- // * Replace the initial prompt
- // * Shift everything by guidance_offset
- embd_guidance = guidance_inp;
- if (embd.begin() + original_prompt_len < embd.end()) {
- embd_guidance.insert(
- embd_guidance.end(),
- embd.begin() + original_prompt_len,
- embd.end()
- );
- }
-
- input_buf = embd_guidance.data();
- input_size = embd_guidance.size();
-
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
- } else {
- input_buf = embd.data();
- input_size = embd.size();
- }
-
- for (int i = 0; i < input_size; i += params.n_batch) {
- int n_eval = std::min(input_size - i, params.n_batch);
- if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
- LOG_TEE("%s : failed to eval\n", __func__);
- return 1;
- }
-
- n_past_guidance += n_eval;
- }
- }
-
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
int n_eval = (int) embd.size() - i;
if (n_eval > params.n_batch) {
@@ -525,11 +414,9 @@ int main(int argc, char ** argv) {
}
embd.clear();
- embd_guidance.clear();
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
-
- const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
+ const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
llama_sampling_accept(ctx_sampling, ctx, id, true);
@@ -583,7 +470,6 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
-
// deal with eot token in infill mode
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) {
@@ -644,7 +530,6 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(model));
embd.clear();
- embd_guidance.clear();
n_remain = params.n_predict;
n_past = 0;
n_consumed = 0;
@@ -751,7 +636,6 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
- if (ctx_guidance) { llama_free(ctx_guidance); }
llama_free(ctx);
llama_free_model(model);