summaryrefslogtreecommitdiff
path: root/examples/main/main.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--examples/main/main.cpp12
1 files changed, 7 insertions, 5 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index f5d2f489..7555dffe 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -334,6 +334,8 @@ int main(int argc, char ** argv) {
// number of tokens to keep when resetting context
if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) {
params.n_keep = (int)embd_inp.size();
+ } else {
+ params.n_keep += add_bos; // always keep the BOS token
}
// prefix & suffix for instruct mode
@@ -383,8 +385,8 @@ int main(int argc, char ** argv) {
}
}
- if (params.n_keep > 0) {
- LOG_TEE("%s: static prompt based on n_keep: '", __func__);
+ if (params.n_keep > add_bos) {
+ LOG_TEE("%s: static prompt based on n_keep: '", __func__);
for (int i = 0; i < params.n_keep; i++) {
LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str());
}
@@ -540,14 +542,14 @@ int main(int argc, char ** argv) {
break;
}
- const int n_left = n_past - params.n_keep - 1;
+ const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
- llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
- llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+ llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
+ llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;