diff options
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 95527bb8..6e4d40b9 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -35,11 +35,11 @@ int main(int argc, char ** argv) { auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); // init - auto model = llama_load_model_from_file(params.model.c_str(), lparams); + auto * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == nullptr) { return 1; } - auto ctx = llama_new_context_with_model(model, lparams); + auto * ctx = llama_new_context_with_model(model, lparams); if (ctx == nullptr) { llama_free_model(model); return 1; @@ -54,7 +54,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { printf("\n%s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx); + auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); @@ -91,7 +91,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -106,7 +106,7 @@ int main(int argc, char ** argv) { llama_free(ctx); // make new context - auto ctx2 = llama_new_context_with_model(model, lparams); + auto * ctx2 = llama_new_context_with_model(model, lparams); // Load state (rng, logits, embedding and kv_cache) from file { @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx2); + auto * logits = llama_get_logits(ctx2); auto n_vocab = llama_n_vocab(ctx2); std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); @@ -151,7 +151,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); |