Diffstat (limited to 'examples/main/main.cpp')
 examples/main/main.cpp | 39 +++++++++++++++------------------------
 1 file changed, 15 insertions(+), 24 deletions(-)
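The hunks below switch the special-token getters (llama_token_bos, llama_token_eos, llama_token_nl) to their context-taking variants, adjust llama_token_to_str callers for its new std::string return type, and drop both the leading-space prompt insertion and the unused newline-token lookup.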
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a632bea1..388e1f7d 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
{
fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
- const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+ const std::vector<llama_token> tmp(params.n_batch, llama_token_bos(ctx));
llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
}
@@ -191,10 +191,6 @@ int main(int argc, char ** argv) {
// tokenize the prompt
std::vector<llama_token> embd_inp;
-
- // Add a space in front of the first character to match OG llama tokenizer behavior
- params.prompt.insert(0, 1, ' ');
-
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
embd_inp = ::llama_tokenize(ctx, params.prompt, true);
} else {
@@ -270,15 +266,12 @@ int main(int argc, char ** argv) {
params.interactive = true;
}
- // determine newline token
- auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
-
if (params.verbose_prompt) {
fprintf(stderr, "\n");
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
for (int i = 0; i < (int) embd_inp.size(); i++) {
- fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
+ fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
}
if (ctx_guidance) {
@@ -286,14 +279,14 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
- fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+ fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
}
}
if (params.n_keep > 0) {
fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
- fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
+ fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
}
fprintf(stderr, "'\n");
}
@@ -311,7 +304,7 @@ int main(int argc, char ** argv) {
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
};
- SetConsoleCtrlHandler(static_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+ SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
fprintf(stderr, "%s: interactive mode on.\n", __func__);
@@ -352,10 +345,9 @@ int main(int argc, char ** argv) {
fprintf(stderr, "\n");
{
- auto it = params.logit_bias.find(llama_token_eos());
+ auto it = params.logit_bias.find(llama_token_eos(ctx));
if (it != params.logit_bias.end() && it->second == -INFINITY) {
- fprintf(stderr,
- "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+ fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
@@ -405,7 +397,7 @@ int main(int argc, char ** argv) {
// do one empty run to warm up the model
{
- const std::vector<llama_token> tmp = { llama_token_bos(), };
+ const std::vector<llama_token> tmp = { llama_token_bos(ctx), };
llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
llama_reset_timings(ctx);
}
@@ -589,7 +581,7 @@ int main(int argc, char ** argv) {
}
// Apply penalties
- float nl_logit = logits[llama_token_nl()];
+ float nl_logit = logits[llama_token_nl(ctx)];
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
llama_sample_repetition_penalty(ctx, &candidates_p,
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -598,7 +590,7 @@ int main(int argc, char ** argv) {
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl) {
- logits[llama_token_nl()] = nl_logit;
+ logits[llama_token_nl(ctx)] = nl_logit;
}
if (grammar != NULL) {
@@ -662,7 +654,7 @@ int main(int argc, char ** argv) {
// display text
if (input_echo) {
for (auto id : embd) {
- printf("%s", llama_token_to_str(ctx, id));
+ printf("%s", llama_token_to_str(ctx, id).c_str());
}
fflush(stdout);
}
@@ -704,7 +696,7 @@ int main(int argc, char ** argv) {
}
// deal with end of text token in interactive mode
- if (last_n_tokens.back() == llama_token_eos()) {
+ if (last_n_tokens.back() == llama_token_eos(ctx)) {
if (params.interactive) {
if (params.antiprompt.size() != 0) {
// tokenize and inject first reverse prompt
@@ -728,7 +720,7 @@ int main(int argc, char ** argv) {
}
if (params.input_prefix_bos) {
- embd_inp.push_back(llama_token_bos());
+ embd_inp.push_back(llama_token_bos(ctx));
}
std::string buffer;
@@ -782,8 +774,7 @@ int main(int argc, char ** argv) {
if (grammar != NULL) {
llama_grammar_free(grammar);
- std::vector<const llama_grammar_element *> grammar_rules(
- parsed_grammar.c_rules());
+ std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(),
parsed_grammar.symbol_ids.at("root"));
@@ -794,7 +785,7 @@ int main(int argc, char ** argv) {
}
// end of text token
- if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+ if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
fprintf(stderr, " [end of text]\n");
break;
}
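For reference, a minimal sketch of the calling pattern after this change. It assumes only the post-change signatures visible in the hunks above (token getters taking a llama_context, llama_token_to_str returning std::string); the helper names are illustrative and not part of the commit:

    // Sketch, not from the commit: post-change usage of the token helpers.
    // Assumes a valid llama_context obtained elsewhere (model loading omitted).
    #include "llama.h"
    #include <cstdio>
    #include <vector>

    // llama_token_to_str now returns std::string, so printf-style output
    // needs .c_str() instead of passing the return value directly.
    static void print_token(llama_context * ctx, llama_token id) {
        fprintf(stdout, "%s", llama_token_to_str(ctx, id).c_str());
    }

    // The special-token getters take the context now, e.g. the warm-up run:
    static void warm_up(llama_context * ctx, int n_threads) {
        const std::vector<llama_token> tmp = { llama_token_bos(ctx) };
        llama_eval(ctx, tmp.data(), tmp.size(), 0, n_threads);
        llama_reset_timings(ctx);
    }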