summaryrefslogtreecommitdiff
path: root/examples/save-load-state/save-load-state.cpp
diff options
context:
space:
mode:
authorJan Boon <jan.boon@kaetemi.be>2024-04-08 20:43:30 +0800
committerGitHub <noreply@github.com>2024-04-08 15:43:30 +0300
commitbeea6e1b16e783a0886e78dec01002a8c00db24d (patch)
treea7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/save-load-state/save-load-state.cpp
parent87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff)
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id * remove trailing whitespace * respond error in case there's no space in the kv cache * add kv seq save restore to test case * add --slot-save-path arg to enable save restore and restrict save location * Returning 0 for some cases, instead of asserting. * cleanup error cases * rename sequence state functions * rename state get set functions * add previous function names back in with DEPRECATED notice * update doc * adjust endpoints to preferred style * fix restoring zero cell count * handle seq rm return value * unused param * keep in the size check * fix return types * add server test case for slot save restore * cleanup * add cake * cleanup style * add special * removing a whole sequence never fails * move sequence state file functionality from server to llama to match session api and add version tags * catch exceptions on save as well * error log messages * check types for stricter restore * update server doc * readme : update API changes date * strict filename validation * move include, reject bom as well * also reject empty filename * reject whitespace and trailing dot --------- Co-authored-by: Martin Evans <martindevans@gmail.com> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r--examples/save-load-state/save-load-state.cpp101
1 files changed, 95 insertions, 6 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index ef952e2b..c3b76688 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -24,6 +24,7 @@ int main(int argc, char ** argv) {
std::string result0;
std::string result1;
+ std::string result2;
// init
llama_model * model;
@@ -44,8 +45,8 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
- std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
- const size_t written = llama_copy_state_data(ctx, state_mem.data());
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+ const size_t written = llama_state_get_data(ctx, state_mem.data());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@@ -97,13 +98,13 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
- std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
FILE * fp_read = fopen("dump_state.bin", "rb");
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
- if (read != llama_set_state_data(ctx2, state_mem.data())) {
+ if (read != llama_state_set_data(ctx2, state_mem.data())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -141,16 +142,104 @@ int main(int argc, char ** argv) {
n_past += 1;
}
- printf("\n");
+ printf("\n\n");
llama_free(ctx2);
- llama_free_model(model);
if (result0 != result1) {
fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__);
return 1;
}
+ // make new context
+ auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
+
+ printf("\nsingle seq run: %s", params.prompt.c_str());
+
+ // load state (rng, logits, embedding and kv_cache) from file
+ {
+ std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+
+ FILE * fp_read = fopen("dump_state.bin", "rb");
+ const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
+ fclose(fp_read);
+
+ if (read != llama_state_set_data(ctx3, state_mem.data())) {
+ fprintf(stderr, "\n%s : failed to read state\n", __func__);
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+
+ fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size());
+ }
+
+ // restore state (last tokens)
+ n_past = n_past_saved;
+
+ // save seq 0 and load into seq 1
+ {
+ // save kv of seq 0
+ std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
+ const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+ if (ncopy != seq_store.size()) {
+ fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
+
+ // erase whole kv
+ llama_kv_cache_clear(ctx3);
+ fprintf(stderr, "%s : kv cache cleared\n", __func__);
+
+ // restore kv into seq 1
+ const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+ if (nset != seq_store.size()) {
+ fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset);
+ }
+
+ // third run with seq 1 instead of 0
+ for (auto i = 0; i < params.n_predict; i++) {
+ auto * logits = llama_get_logits(ctx3);
+ auto n_vocab = llama_n_vocab(model);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx3, &candidates_p);
+ auto next_token_str = llama_token_to_piece(ctx3, next_token);
+
+ printf("%s", next_token_str.c_str());
+ result2 += next_token_str;
+
+ if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) {
+ fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+ llama_free(ctx3);
+ llama_free_model(model);
+ return 1;
+ }
+ n_past += 1;
+ }
+
+ printf("\n");
+
+ llama_free(ctx3);
+ llama_free_model(model);
+
+ if (result0 != result2) {
+ fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__);
+ return 1;
+ }
+
fprintf(stderr, "\n%s : success\n", __func__);
return 0;