summaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp92
1 files changed, 26 insertions, 66 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 7313d06a..1a5911c5 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -3,7 +3,6 @@
#include "console.h"
#include "llama.h"
#include "build-info.h"
-#include "grammar-parser.h"
#include <cassert>
#include <cinttypes>
@@ -245,12 +244,12 @@ int main(int argc, char ** argv) {
}
LOG("prompt: \"%s\"\n", log_tostr(params.prompt));
- LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
// Should not run without any tokens
if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(ctx));
- LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp));
+ LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
}
// Tokenize negative prompt
@@ -261,10 +260,10 @@ int main(int argc, char ** argv) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
- LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
+ LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
- LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
+ LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
original_prompt_len = original_inp.size();
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
@@ -323,8 +322,8 @@ int main(int argc, char ** argv) {
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
- LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
- LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
+ LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str());
+ LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str());
// in instruct mode, we inject a prefix and a suffix to each input by the user
if (params.instruct) {
@@ -421,35 +420,6 @@ int main(int argc, char ** argv) {
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
- struct llama_grammar * grammar = NULL;
- grammar_parser::parse_state parsed_grammar;
-
- if (!params.grammar.empty()) {
- parsed_grammar = grammar_parser::parse(params.grammar.c_str());
- // will be empty (default) if there are parse errors
- if (parsed_grammar.rules.empty()) {
- return 1;
- }
- LOG_TEE("%s: grammar:\n", __func__);
- grammar_parser::print_grammar(stderr, parsed_grammar);
- LOG_TEE("\n");
-
- {
- auto it = sparams.logit_bias.find(llama_token_eos(ctx));
- if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
- LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
- }
- }
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
- }
-
- // TODO: replace with ring-buffer
- std::vector<llama_token> last_tokens(n_ctx);
- std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
if (params.interactive) {
const char *control_message;
if (params.multiline_input) {
@@ -489,11 +459,7 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd;
std::vector<llama_token> embd_guidance;
- const int n_vocab = llama_n_vocab(model);
-
- llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
+ struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
// predict
@@ -540,7 +506,7 @@ int main(int argc, char ** argv) {
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
- LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
LOG("clear session path\n");
path_session.clear();
@@ -570,7 +536,6 @@ int main(int argc, char ** argv) {
// evaluate tokens in batches
// embd is typically prepared beforehand to fit within a batch, but not always
-
if (ctx_guidance) {
int input_size = 0;
llama_token * input_buf = NULL;
@@ -592,7 +557,7 @@ int main(int argc, char ** argv) {
input_buf = embd_guidance.data();
input_size = embd_guidance.size();
- LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance));
+ LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
} else {
input_buf = embd.data();
input_size = embd.size();
@@ -615,7 +580,7 @@ int main(int argc, char ** argv) {
n_eval = params.n_batch;
}
- LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
+ LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
@@ -645,12 +610,11 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
- const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(id);
+ llama_sampling_accept(ctx_sampling, ctx, id);
- LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
+ LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
embd.push_back(id);
@@ -666,8 +630,14 @@ int main(int argc, char ** argv) {
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
while ((int) embd_inp.size() > n_consumed) {
embd.push_back(embd_inp[n_consumed]);
- last_tokens.erase(last_tokens.begin());
- last_tokens.push_back(embd_inp[n_consumed]);
+
+ // GG: I'm not sure it's a good idea to push the prompt tokens into the sampling context
+ // Most likely will remove this in the future to avoid exposing "prev"
+ // Same thing is done in "server". If we stop pushing the prompt tokens, then the repetition
+ // penalty will be applied only based on the tokens generated by the model.
+ ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+ ctx_sampling->prev.push_back(embd_inp[n_consumed]);
+
++n_consumed;
if ((int) embd.size() >= params.n_batch) {
break;
@@ -700,7 +670,7 @@ int main(int argc, char ** argv) {
// check for reverse prompt
if (!params.antiprompt.empty()) {
std::string last_output;
- for (auto id : last_tokens) {
+ for (auto id : ctx_sampling->prev) {
last_output += llama_token_to_piece(ctx, id);
}
@@ -729,7 +699,7 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
- if (last_tokens.back() == llama_token_eos(ctx)) {
+ if (ctx_sampling->prev.back() == llama_token_eos(ctx)) {
LOG("found EOS token\n");
if (params.interactive) {
@@ -801,7 +771,7 @@ int main(int argc, char ** argv) {
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
- LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
+ LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
@@ -830,15 +800,7 @@ int main(int argc, char ** argv) {
if (n_past > 0) {
if (is_interacting) {
- // reset grammar state if we're restarting generation
- if (grammar != NULL) {
- llama_grammar_free(grammar);
-
- std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
- grammar = llama_grammar_init(
- grammar_rules.data(), grammar_rules.size(),
- parsed_grammar.symbol_ids.at("root"));
- }
+ llama_sampling_reset(ctx_sampling);
}
is_interacting = false;
}
@@ -870,9 +832,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
llama_free_model(model);
- if (grammar != NULL) {
- llama_grammar_free(grammar);
- }
+ llama_sampling_free(ctx_sampling);
llama_backend_free();
#ifndef LOG_DISABLE_LOGS