author     Evan Jones <evan.q.jones@gmail.com>    2023-05-10 11:37:14 -0400
committer  GitHub <noreply@github.com>            2023-05-10 11:37:14 -0400
commit     cf348a60e0af3905acd1d297cb064b918265b7ac (patch)
tree       b5480b47918c0d1f386db71a195028fd5ca095be /examples/main/main.cpp
parent     e6a46b0ed1884c77267dc70693183e3b7164e0e0 (diff)
main : add option to save full output to session (#1338)
* main : add option to save full output to session
* split behavior into --session and --prompt-cache
* restore original implementation with new names
* PR comments
* move the check for incompatible parameters to gpt_params_parse
* Fix whitespace

Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>

---------

Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--   examples/main/main.cpp   20
1 file changed, 10 insertions(+), 10 deletions(-)
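
One of the bullets above, moving the check for incompatible parameters into gpt_params_parse, lands outside this file and therefore does not show up in the diff below. As a rough, hedged sketch only (the gpt_params fields other than prompt_cache_all and interactive_first, which do appear in this diff, are assumed from the surrounding llama.cpp examples code, and the helper name here is made up for illustration), the parsing-side check amounts to rejecting --prompt-cache-all in interactive settings:

    // Hedged sketch, not the actual common.cpp code: with --prompt-cache-all the
    // whole transcript is written back to the cache file, which does not mix with
    // interactive input, so the parser can reject the combination up front.
    #include <cstdio>
    #include <cstdlib>
    #include "common.h" // gpt_params from the llama.cpp examples

    static void check_prompt_cache_params(const gpt_params & params) {
        if (params.prompt_cache_all &&
            (params.interactive || params.interactive_first || !params.antiprompt.empty())) {
            fprintf(stderr, "error: --prompt-cache-all is not supported in interactive mode\n");
            exit(1);
        }
    }

Keeping this in gpt_params_parse means main.cpp never sees the conflicting combination, which is what the "move the check for incompatible parameters" bullet refers to.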
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e1172a4..bd1c4ab5 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -139,7 +139,7 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     params.prompt.insert(0, 1, ' ');
-    std::string path_session = params.path_session;
+    std::string path_session = params.path_prompt_cache;
     std::vector<llama_token> session_tokens;
     if (!path_session.empty()) {
@@ -292,14 +292,9 @@ int main(int argc, char ** argv) {
         is_interacting = params.interactive_first;
     }
-    bool is_antiprompt = false;
-    bool input_echo = true;
-
-    // HACK - because session saving incurs a non-negligible delay, for now skip re-saving session
-    // if we loaded a session with at least 75% similarity. It's currently just used to speed up the
-    // initial prompt so it doesn't need to be an exact match.
-    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
-
+    bool is_antiprompt = false;
+    bool input_echo = true;
+    bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size();
     int n_past = 0;
     int n_remain = params.n_predict;
@@ -328,7 +323,7 @@ int main(int argc, char ** argv) {
             embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
             // stop saving session if we run out of context
-            path_session = "";
+            path_session.clear();
             //printf("\n---\n");
             //printf("resetting: '");
@@ -603,6 +598,11 @@ int main(int argc, char ** argv) {
         }
     }
+    if (!path_session.empty() && params.prompt_cache_all) {
+        fprintf(stderr, "\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
+        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+    }
+
     llama_print_timings(ctx);
     llama_free(ctx);
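
Combined with the session loading that already happens near the top of main.cpp (first hunk), the block added at the end gives the prompt cache a full round trip: load cached tokens, generate, then write everything back when --prompt-cache-all is set. A minimal, illustrative sketch of that flow, assuming ctx, params and path_session are set up as in main.cpp and using the llama.h session API of this period:

    // Illustrative only, not the file's actual code.
    std::vector<llama_token> session_tokens(llama_n_ctx(ctx));
    size_t n_token_count = 0;

    // Reuse whatever a previous run cached, if the file exists and is compatible.
    if (!path_session.empty() &&
        llama_load_session_file(ctx, path_session.c_str(),
                                session_tokens.data(), session_tokens.size(), &n_token_count)) {
        session_tokens.resize(n_token_count); // these tokens can skip re-evaluation
    } else {
        session_tokens.clear();
    }

    // ... evaluate the prompt and sample, appending every processed token to session_tokens ...

    // New in this commit: with --prompt-cache-all, persist the prompt *and* the generated output.
    if (!path_session.empty() && params.prompt_cache_all) {
        llama_save_session_file(ctx, path_session.c_str(),
                                session_tokens.data(), session_tokens.size());
    }

A rerun that shares a prefix with the cached tokens can then skip most of the prompt evaluation, which is the speed-up the prompt cache exists to provide.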