summaryrefslogtreecommitdiff
path: root/examples/save-load-state/save-load-state.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r--examples/save-load-state/save-load-state.cpp27
1 files changed, 17 insertions, 10 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 00c2277a..3ea7c790 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -28,10 +28,11 @@ int main(int argc, char ** argv) {
std::string result2;
// init
- llama_model * model;
- llama_context * ctx;
+ llama_init_result llama_init = llama_init_from_gpt_params(params);
+
+ llama_model * model = llama_init.model;
+ llama_context * ctx = llama_init.context;
- std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == nullptr || ctx == nullptr) {
fprintf(stderr, "%s : failed to init\n", __func__);
return 1;
@@ -47,7 +48,7 @@ int main(int argc, char ** argv) {
// save state (rng, logits, embedding and kv_cache) to file
{
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
- const size_t written = llama_state_get_data(ctx, state_mem.data());
+ const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
FILE *fp_write = fopen("dump_state.bin", "wb");
fwrite(state_mem.data(), 1, written, fp_write);
@@ -99,13 +100,16 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
- std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+ std::vector<uint8_t> state_mem;
FILE * fp_read = fopen("dump_state.bin", "rb");
+ fseek(fp_read, 0, SEEK_END);
+ state_mem.resize(ftell(fp_read));
+ fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
- if (read != llama_state_set_data(ctx2, state_mem.data())) {
+ if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
@@ -159,13 +163,16 @@ int main(int argc, char ** argv) {
// load state (rng, logits, embedding and kv_cache) from file
{
- std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+ std::vector<uint8_t> state_mem;
FILE * fp_read = fopen("dump_state.bin", "rb");
+ fseek(fp_read, 0, SEEK_END);
+ state_mem.resize(ftell(fp_read));
+ fseek(fp_read, 0, SEEK_SET);
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
fclose(fp_read);
- if (read != llama_state_set_data(ctx3, state_mem.data())) {
+ if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx3);
llama_free_model(model);
@@ -182,7 +189,7 @@ int main(int argc, char ** argv) {
{
// save kv of seq 0
std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
- const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+ const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
if (ncopy != seq_store.size()) {
fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
llama_free(ctx3);
@@ -196,7 +203,7 @@ int main(int argc, char ** argv) {
fprintf(stderr, "%s : kv cache cleared\n", __func__);
// restore kv into seq 1
- const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+ const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
if (nset != seq_store.size()) {
fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
llama_free(ctx3);