summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--llama.cpp27
1 files changed, 18 insertions, 9 deletions
diff --git a/llama.cpp b/llama.cpp
index 1d1db8fc..d8510a5c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1552,14 +1552,14 @@ static void llama_kv_cache_seq_shift(
for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
- cache.cells[i].pos += delta;
+ cache.has_shift = true;
+ cache.cells[i].pos += delta;
+ cache.cells[i].delta += delta;
+
if (cache.cells[i].pos < 0) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
- } else {
- cache.has_shift = true;
- cache.cells[i].delta = delta;
}
}
}
@@ -6073,11 +6073,20 @@ static int llama_decode_internal(
#endif
// update the kv ring buffer
- lctx.kv_self.has_shift = false;
- lctx.kv_self.head += n_tokens;
- // Ensure kv cache head points to a valid index.
- if (lctx.kv_self.head >= lctx.kv_self.size) {
- lctx.kv_self.head = 0;
+ {
+ if (kv_self.has_shift) {
+ kv_self.has_shift = false;
+ for (uint32_t i = 0; i < kv_self.size; ++i) {
+ kv_self.cells[i].delta = 0;
+ }
+ }
+
+ kv_self.head += n_tokens;
+
+ // Ensure kv cache head points to a valid index.
+ if (kv_self.head >= kv_self.size) {
+ kv_self.head = 0;
+ }
}
#ifdef GGML_PERF