diff options
author | Jan Boon <jan.boon@kaetemi.be> | 2024-04-08 20:43:30 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-08 15:43:30 +0300 |
commit | beea6e1b16e783a0886e78dec01002a8c00db24d (patch) | |
tree | a7365b1e93145b78a8b4be72df959239aa8c0f0d /examples/save-load-state/save-load-state.cpp | |
parent | 87fb5b4234d4b9c56ac94cf7aa229c8fd7defdb0 (diff) |
llama : save and restore kv cache for single seq id (#6341)
* llama : save and restore kv cache for single seq id
* remove trailing whitespace
* respond error in case there's no space in the kv cache
* add kv seq save restore to test case
* add --slot-save-path arg to enable save restore and restrict save location
* Returning 0 for some cases, instead of asserting.
* cleanup error cases
* rename sequence state functions
* rename state get set functions
* add previous function names back in with DEPRECATED notice
* update doc
* adjust endpoints to preferred style
* fix restoring zero cell count
* handle seq rm return value
* unused param
* keep in the size check
* fix return types
* add server test case for slot save restore
* cleanup
* add cake
* cleanup style
* add special
* removing a whole sequence never fails
* move sequence state file functionality from server to llama to match session api and add version tags
* catch exceptions on save as well
* error log messages
* check types for stricter restore
* update server doc
* readme : update API changes date
* strict filename validation
* move include, reject bom as well
* also reject empty filename
* reject whitespace and trailing dot
---------
Co-authored-by: Martin Evans <martindevans@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 101 |
1 file changed, 95 insertions, 6 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ef952e2b..c3b76688 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -24,6 +24,7 @@ int main(int argc, char ** argv) { std::string result0; std::string result1; + std::string result2; // init llama_model * model; @@ -44,8 +45,8 @@ int main(int argc, char ** argv) { // save state (rng, logits, embedding and kv_cache) to file { - std::vector<uint8_t> state_mem(llama_get_state_size(ctx)); - const size_t written = llama_copy_state_data(ctx, state_mem.data()); + std::vector<uint8_t> state_mem(llama_state_get_size(ctx)); + const size_t written = llama_state_get_data(ctx, state_mem.data()); FILE *fp_write = fopen("dump_state.bin", "wb"); fwrite(state_mem.data(), 1, written, fp_write); @@ -97,13 +98,13 @@ int main(int argc, char ** argv) { // load state (rng, logits, embedding and kv_cache) from file { - std::vector<uint8_t> state_mem(llama_get_state_size(ctx2)); + std::vector<uint8_t> state_mem(llama_state_get_size(ctx2)); FILE * fp_read = fopen("dump_state.bin", "rb"); const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); fclose(fp_read); - if (read != llama_set_state_data(ctx2, state_mem.data())) { + if (read != llama_state_set_data(ctx2, state_mem.data())) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx2); llama_free_model(model); @@ -141,16 +142,104 @@ int main(int argc, char ** argv) { n_past += 1; } - printf("\n"); + printf("\n\n"); llama_free(ctx2); - llama_free_model(model); if (result0 != result1) { fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); return 1; } + // make new context + auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + printf("\nsingle seq run: %s", params.prompt.c_str()); + + // load state (rng, logits, embedding and kv_cache) from file + { + 
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3)); + + FILE * fp_read = fopen("dump_state.bin", "rb"); + const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); + fclose(fp_read); + + if (read != llama_state_set_data(ctx3, state_mem.data())) { + fprintf(stderr, "\n%s : failed to read state\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + + fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + } + + // restore state (last tokens) + n_past = n_past_saved; + + // save seq 0 and load into seq 1 + { + // save kv of seq 0 + std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0)); + const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0); + if (ncopy != seq_store.size()) { + fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); + + // erase whole kv + llama_kv_cache_clear(ctx3); + fprintf(stderr, "%s : kv cache cleared\n", __func__); + + // restore kv into seq 1 + const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1); + if (nset != seq_store.size()) { + fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size()); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + fprintf(stderr, "%s : seq 1 restored, %zd bytes\n", __func__, nset); + } + + // third run with seq 1 instead of 0 + for (auto i = 0; i < params.n_predict; i++) { + auto * logits = llama_get_logits(ctx3); + auto n_vocab = llama_n_vocab(model); + std::vector<llama_token_data> candidates; + candidates.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + 
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token_str = llama_token_to_piece(ctx3, next_token); + + printf("%s", next_token_str.c_str()); + result2 += next_token_str; + + if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1, n_past, 1))) { + fprintf(stderr, "\n%s : failed to evaluate\n", __func__); + llama_free(ctx3); + llama_free_model(model); + return 1; + } + n_past += 1; + } + + printf("\n"); + + llama_free(ctx3); + llama_free_model(model); + + if (result0 != result2) { + fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); + return 1; + } + fprintf(stderr, "\n%s : success\n", __func__); return 0; |