summaryrefslogtreecommitdiff
path: root/examples/save-load-state
diff options
context:
space:
mode:
Diffstat (limited to 'examples/save-load-state')
-rw-r--r--examples/save-load-state/save-load-state.cpp34
1 files changed, 18 insertions, 16 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 39aa7f82..07dfa2c7 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -64,14 +64,15 @@ int main(int argc, char ** argv) {
// first run
printf("\n%s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
- auto next_token = llama_sample_top_p_top_k(
- ctx,
- &last_n_tokens_data.back() - params.repeat_last_n,
- params.repeat_last_n,
- 40,
- 1.0,
- 1.0,
- 1.1);
+ auto logits = llama_get_logits(ctx);
+ auto n_vocab = llama_n_vocab(ctx);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx, &candidates_p);
auto next_token_str = llama_token_to_str(ctx, next_token);
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str);
@@ -106,14 +107,15 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
- auto next_token = llama_sample_top_p_top_k(
- ctx2,
- &last_n_tokens_data.back() - params.repeat_last_n,
- params.repeat_last_n,
- 40,
- 1.0,
- 1.0,
- 1.1);
+ auto logits = llama_get_logits(ctx2);
+ auto n_vocab = llama_n_vocab(ctx2);
+ std::vector<llama_token_data> candidates;
+ candidates.reserve(n_vocab);
+ for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+ candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+ }
+ llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+ auto next_token = llama_sample_token(ctx2, &candidates_p);
auto next_token_str = llama_token_to_str(ctx2, next_token);
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str);