Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp | 36
1 file changed, 16 insertions(+), 20 deletions(-)
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index d7811226..1ed543cb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); });
- if (params.perplexity) {
+ if (params.logits_all) {
printf("\n************\n");
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
printf("************\n\n");
@@ -200,15 +200,6 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
}
- // export the cgraph and exit
- if (params.export_cgraph) {
- llama_eval_export(ctx, "llama.ggml");
- llama_free(ctx);
- llama_free_model(model);
-
- return 0;
- }
-
std::string path_session = params.path_prompt_cache;
std::vector<llama_token> session_tokens;
@@ -508,17 +499,22 @@ int main(int argc, char ** argv) {
break;
}
- const int n_left = n_past - params.n_keep;
- LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
+ const int n_left = n_past - params.n_keep - 1;
+ const int n_discard = n_left/2;
- // always keep the first token - BOS
- n_past = std::max(1, params.n_keep);
- n_past_guidance = std::max(1, params.n_keep + guidance_offset);
+ LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+ n_past, n_left, n_ctx, params.n_keep, n_discard);
- LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
+ llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+ llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
- // insert n_left/2 tokens at the start of embd from last_tokens
- embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
+ n_past -= n_discard;
+
+ if (ctx_guidance) {
+ n_past_guidance -= n_discard;
+ }
+
+ LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
@@ -580,7 +576,7 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
- if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
+ if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
@@ -597,7 +593,7 @@ int main(int argc, char ** argv) {
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
- if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
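
The last two hunks replace llama_eval() with llama_decode() over a single-sequence batch built by llama_batch_get_one(). A minimal sketch of that evaluation loop, following the call signatures exactly as they appear in this diff (the helper name eval_tokens is hypothetical; n_threads is passed through because the llama_decode shown here still takes it):

    #include <algorithm>
    #include "llama.h"

    // Evaluate a token buffer in chunks of at most n_batch tokens, advancing
    // n_past by the number of tokens actually decoded.
    static bool eval_tokens(llama_context * ctx, llama_token * tokens, int n_tokens,
                            int & n_past, int n_batch, int n_threads) {
        for (int i = 0; i < n_tokens; i += n_batch) {
            const int n_eval = std::min(n_tokens - i, n_batch);

            // one chunk starting at KV position n_past, sequence id 0
            if (llama_decode(ctx, llama_batch_get_one(tokens + i, n_eval, n_past, 0), n_threads)) {
                return false; // decode failed
            }
            n_past += n_eval;
        }
        return true;
    }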