summaryrefslogtreecommitdiff
path: root/examples/infill/infill.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/infill/infill.cpp')
-rw-r--r--examples/infill/infill.cpp44
1 files changed, 18 insertions, 26 deletions
diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 187623f5..128d6708 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -257,12 +257,12 @@ int main(int argc, char ** argv) {
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
- LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@@ -273,10 +273,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+ LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+ LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -294,8 +294,8 @@ int main(int argc, char ** argv) {
params.n_keep = (int)embd_inp.size();
}
- LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
- LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+ LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+ LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// enable interactive mode if interactive start is specified
@@ -388,9 +388,6 @@ int main(int argc, char ** argv) {
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
- // TODO: replace with ring-buffer
- std::vector<llama_token> last_tokens(n_ctx);
- std::fill(last_tokens.begin(), last_tokens.end(), 0);
LOG_TEE("\n##### Infill mode #####\n\n");
if (params.infill) {
printf("\n************\n");
@@ -433,11 +430,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- const int n_vocab = llama_n_vocab(model);
-
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while (n_remain != 0 || params.interactive) {
// predict
@@ -484,7 +477,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
- LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
}
@@ -512,7 +505,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+ LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -535,7 +528,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
- LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -554,12 +547,11 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
- const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
+ llama_sampling_accept(ctx_sampling, ctx, id);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+ LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@@ -575,8 +567,8 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(embd_inp[n_consumed]);
+ ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ ctx_sampling->prev.push_back(embd_inp[n_consumed]);
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -608,7 +600,7 @@ int main(int argc, char ** argv) {
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
- if ((last_tokens.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
+ if ((ctx_sampling->prev.back() == llama_token_eot(ctx) || is_interacting) && params.interactive){
if(is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(ctx)).c_str());
@@ -675,7 +667,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of text token in interactive mode
- else if (last_tokens.back() == llama_token_eos(ctx)) {
+ else if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@@ -727,7 +719,7 @@ int main(int argc, char ** argv) {
const size_t original_size = embd_inp.size();
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
- LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+ LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());