author    Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-05-25 20:18:01 -0600
committer GitHub <noreply@github.com>  2023-05-25 20:18:01 -0600
commit    66874d4fbcc7866377246efbcee938e8cc9c7d76
tree      503b2c4e6ddb6124342a218572166a587d8f3acb  /examples/main/main.cpp
parent    1fcdcc28b119a6608774d52de905931bd5f8a43d
Some improvements to loading the session with --prompt-cache (#1550)
Improvements to loading the session with `--prompt-cache` in the `main` example:
1. Fix an issue where the `--seed` parameter was ignored when loading a cached prompt.
2. When loading a cached prompt, you previously had to specify the saved prompt (or a prefix of it) again. This pull request changes that behavior to default to the cached prompt when the user doesn't specify one.
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp | 18
1 file changed, 14 insertions(+), 4 deletions(-)
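
For orientation before the raw diff, below is a minimal standalone sketch of the session-loading flow that results from this patch. It is a hedged illustration, not the literal main.cpp code: it assumes the llama.cpp API of this period (llama_load_session_file, llama_set_rng_seed) and the gpt_params / ::llama_tokenize helpers from examples/common.h, and it simplifies the error handling around a missing or unreadable cache file.

    // build_input.cpp -- sketch only; assumes "common.h" / "llama.h" from this era of llama.cpp
    #include "common.h"
    #include "llama.h"

    #include <string>
    #include <vector>

    static std::vector<llama_token> build_input(llama_context * ctx, gpt_params & params) {
        std::vector<llama_token> session_tokens;

        if (!params.path_prompt_cache.empty()) {
            // try to restore a previously saved prompt cache
            session_tokens.resize(params.n_ctx);
            size_t n_token_count_out = 0;
            if (llama_load_session_file(ctx, params.path_prompt_cache.c_str(),
                                        session_tokens.data(), session_tokens.capacity(),
                                        &n_token_count_out)) {
                session_tokens.resize(n_token_count_out);
                // fix (1): re-apply the user's seed so --seed is honored for cached prompts
                llama_set_rng_seed(ctx, params.seed);
            } else {
                session_tokens.clear();  // no usable cache; fall through to tokenizing
            }
        }

        // fix (2): only tokenize when the user actually supplied a prompt (or is running
        // interactively / in instruct mode); otherwise reuse the tokens stored in the
        // session so the prompt no longer has to be repeated
        if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
            params.prompt.insert(0, 1, ' ');  // leading space matches OG llama tokenizer behavior
            return ::llama_tokenize(ctx, params.prompt, true);
        }
        return session_tokens;
    }

In practice this means a run such as `./main -m model.bin --prompt-cache prompt.bin -p "Once upon a time"` can later be resumed with `./main -m model.bin --prompt-cache prompt.bin --seed 42`: the prompt comes from the cache, and the new seed now takes effect.
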
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 47b418d9..c7c59153 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -134,8 +134,6 @@ int main(int argc, char ** argv) {
return 0;
}
- // Add a space in front of the first character to match OG llama tokenizer behavior
- params.prompt.insert(0, 1, ' ');
std::string path_session = params.path_prompt_cache;
std::vector<llama_token> session_tokens;
@@ -155,6 +153,7 @@ int main(int argc, char ** argv) {
return 1;
}
session_tokens.resize(n_token_count_out);
+ llama_set_rng_seed(ctx, params.seed);
fprintf(stderr, "%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size());
} else {
@@ -163,7 +162,16 @@ int main(int argc, char ** argv) {
}
// tokenize the prompt
- auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+ std::vector<llama_token> embd_inp;
+
+ if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
+ // Add a space in front of the first character to match OG llama tokenizer behavior
+ params.prompt.insert(0, 1, ' ');
+
+ embd_inp = ::llama_tokenize(ctx, params.prompt, true);
+ } else {
+ embd_inp = session_tokens;
+ }
const int n_ctx = llama_n_ctx(ctx);
@@ -181,7 +189,9 @@ int main(int argc, char ** argv) {
}
n_matching_session_tokens++;
}
- if (n_matching_session_tokens >= embd_inp.size()) {
+ if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
+ fprintf(stderr, "%s: using full prompt from session file\n", __func__);
+ } else if (n_matching_session_tokens >= embd_inp.size()) {
fprintf(stderr, "%s: session file has exact match for prompt!\n", __func__);
} else if (n_matching_session_tokens < (embd_inp.size() / 2)) {
fprintf(stderr, "%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n",