summaryrefslogtreecommitdiff
path: root/examples/save-load-state
diff options
context:
space:
mode:
Diffstat (limited to 'examples/save-load-state')
-rw-r--r--examples/save-load-state/save-load-state.cpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 95527bb8..6e4d40b9 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -35,11 +35,11 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+ auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == nullptr) {
return 1;
}
- auto ctx = llama_new_context_with_model(model, lparams);
+ auto * ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
@@ -54,7 +54,7 @@ int main(int argc, char ** argv) {
}
// evaluate prompt
- llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens;
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
printf("\n%s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
- auto logits = llama_get_logits(ctx);
+ auto * logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -91,7 +91,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
// make new context
- auto ctx2 = llama_new_context_with_model(model, lparams);
+ auto * ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
- auto logits = llama_get_logits(ctx2);
+ auto * logits = llama_get_logits(ctx2);
auto n_vocab = llama_n_vocab(ctx2);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);