From 66874d4fbcc7866377246efbcee938e8cc9c7d76 Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Thu, 25 May 2023 20:18:01 -0600 Subject: Some improvements to loading the session with --prompt-cache (#1550) Improvements to loading the session with `--prompt-cache` in the `main` example. 1. Fix an issue where the `--seed` parameter was ignored when loading a cached prompt. 2. When loading a cached prompt, you previously had to specify the saved prompt (or a prefix of it) again. This pull changes that behavior to default to the prompt that was cached if a prompt wasn't specified by the user. --- examples/main/main.cpp | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'examples/main/main.cpp') diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 47b418d9..c7c59153 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -134,8 +134,6 @@ int main(int argc, char ** argv) { return 0; } - // Add a space in front of the first character to match OG llama tokenizer behavior - params.prompt.insert(0, 1, ' '); std::string path_session = params.path_prompt_cache; std::vector session_tokens; @@ -155,6 +153,7 @@ int main(int argc, char ** argv) { return 1; } session_tokens.resize(n_token_count_out); + llama_set_rng_seed(ctx, params.seed); fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); } else { @@ -163,7 +162,16 @@ int main(int argc, char ** argv) { } // tokenize the prompt - auto embd_inp = ::llama_tokenize(ctx, params.prompt, true); + std::vector embd_inp; + + if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { + // Add a space in front of the first character to match OG llama tokenizer behavior + params.prompt.insert(0, 1, ' '); + + embd_inp = ::llama_tokenize(ctx, params.prompt, true); + } else { + embd_inp = session_tokens; + } const int n_ctx = llama_n_ctx(ctx); @@ -181,7 +189,9 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } - if (n_matching_session_tokens >= embd_inp.size()) { + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { + fprintf(stderr, "%s: using full prompt from session file\n", __func__); + } else if (n_matching_session_tokens >= embd_inp.size()) { fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__); } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", -- cgit v1.2.3