Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r-- | examples/main/main.cpp | 39 |
1 files changed, 15 insertions, 24 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index a632bea1..388e1f7d 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -143,7 +143,7 @@ int main(int argc, char ** argv) {
     {
         fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);

-        const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+        const std::vector<llama_token> tmp(params.n_batch, llama_token_bos(ctx));
         llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
     }

@@ -191,10 +191,6 @@ int main(int argc, char ** argv) {

     // tokenize the prompt
     std::vector<llama_token> embd_inp;
-
-    // Add a space in front of the first character to match OG llama tokenizer behavior
-    params.prompt.insert(0, 1, ' ');
-
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         embd_inp = ::llama_tokenize(ctx, params.prompt, true);
     } else {
@@ -270,15 +266,12 @@ int main(int argc, char ** argv) {
         params.interactive = true;
     }

-    // determine newline token
-    auto llama_token_newline = ::llama_tokenize(ctx, "\n", false);
-
     if (params.verbose_prompt) {
         fprintf(stderr, "\n");
         fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
         for (int i = 0; i < (int) embd_inp.size(); i++) {
-            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]));
+            fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str());
         }

         if (ctx_guidance) {
@@ -286,14 +279,14 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
             fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
             for (int i = 0; i < (int) guidance_inp.size(); i++) {
-                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]));
+                fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str());
             }
         }

         if (params.n_keep > 0) {
             fprintf(stderr, "%s: static prompt based on n_keep: '", __func__);
             for (int i = 0; i < params.n_keep; i++) {
-                fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]));
+                fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str());
             }
             fprintf(stderr, "'\n");
         }
@@ -311,7 +304,7 @@ int main(int argc, char ** argv) {
         auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
             return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
         };
-        SetConsoleCtrlHandler(static_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
+        SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif

     fprintf(stderr, "%s: interactive mode on.\n", __func__);
@@ -352,10 +345,9 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "\n");

         {
-            auto it = params.logit_bias.find(llama_token_eos());
+            auto it = params.logit_bias.find(llama_token_eos(ctx));
             if (it != params.logit_bias.end() && it->second == -INFINITY) {
-                fprintf(stderr,
-                    "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
+                fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
             }
         }

@@ -405,7 +397,7 @@ int main(int argc, char ** argv) {

     // do one empty run to warm up the model
     {
-        const std::vector<llama_token> tmp = { llama_token_bos(), };
+        const std::vector<llama_token> tmp = { llama_token_bos(ctx), };
         llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
         llama_reset_timings(ctx);
     }
@@ -589,7 +581,7 @@ int main(int argc, char ** argv) {
                 }

                 // Apply penalties
-                float nl_logit = logits[llama_token_nl()];
+                float nl_logit = logits[llama_token_nl(ctx)];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
                 llama_sample_repetition_penalty(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
@@ -598,7 +590,7 @@ int main(int argc, char ** argv) {
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
                     last_n_repeat, alpha_frequency, alpha_presence);
                 if (!penalize_nl) {
-                    logits[llama_token_nl()] = nl_logit;
+                    logits[llama_token_nl(ctx)] = nl_logit;
                 }

                 if (grammar != NULL) {
@@ -662,7 +654,7 @@ int main(int argc, char ** argv) {
         // display text
         if (input_echo) {
             for (auto id : embd) {
-                printf("%s", llama_token_to_str(ctx, id));
+                printf("%s", llama_token_to_str(ctx, id).c_str());
             }
             fflush(stdout);
         }
@@ -704,7 +696,7 @@ int main(int argc, char ** argv) {
             }

             // deal with end of text token in interactive mode
-            if (last_n_tokens.back() == llama_token_eos()) {
+            if (last_n_tokens.back() == llama_token_eos(ctx)) {
                 if (params.interactive) {
                     if (params.antiprompt.size() != 0) {
                         // tokenize and inject first reverse prompt
@@ -728,7 +720,7 @@ int main(int argc, char ** argv) {
             }

             if (params.input_prefix_bos) {
-                embd_inp.push_back(llama_token_bos());
+                embd_inp.push_back(llama_token_bos(ctx));
             }

             std::string buffer;
@@ -782,8 +774,7 @@ int main(int argc, char ** argv) {
             if (grammar != NULL) {
                 llama_grammar_free(grammar);

-                std::vector<const llama_grammar_element *> grammar_rules(
-                    parsed_grammar.c_rules());
+                std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
                 grammar = llama_grammar_init(
                     grammar_rules.data(), grammar_rules.size(),
                     parsed_grammar.symbol_ids.at("root"));
@@ -794,7 +785,7 @@ int main(int argc, char ** argv) {
         }

         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) {
+        if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) {
             fprintf(stderr, " [end of text]\n");
             break;
         }
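
For context, the hunks above track an API update: the special-token helpers (llama_token_bos, llama_token_eos, llama_token_nl) now take the llama_context, and llama_token_to_str now returns a std::string, so printf-style output needs .c_str(). Below is a minimal sketch (not part of this diff) of the new calling convention; it assumes llama.h from this tree plus the examples' llama_token_to_str(ctx, token) helper, and the hypothetical function names print_tokens and warm_up are for illustration only.

// Sketch of the calling convention the patch moves to (assumptions noted above).
#include <cstdio>
#include <vector>

#include "llama.h"
#include "common.h"   // assumed location of the llama_token_to_str(ctx, token) helper

// Print each token id and its text; .c_str() is needed because the helper returns std::string.
static void print_tokens(llama_context * ctx, const std::vector<llama_token> & tokens) {
    for (auto id : tokens) {
        printf("%6d -> '%s'\n", id, llama_token_to_str(ctx, id).c_str());
    }
}

// Warm-up run: special-token ids such as BOS are now queried through the context.
static void warm_up(llama_context * ctx, int n_threads) {
    const std::vector<llama_token> tmp = { llama_token_bos(ctx) };
    llama_eval(ctx, tmp.data(), tmp.size(), 0, n_threads);
    llama_reset_timings(ctx);
}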