diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-01-08 11:14:04 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-08 11:14:04 +0200 |
commit | b0034d93ce2949ce7d9c098ca02e56f66cd484e2 (patch) | |
tree | 5409bc6706ff5cf5aabc109ea466b73d46bb9839 /llama.cpp | |
parent | b7e7982953f80a656e03feb5cfb17a17a173eb26 (diff) |
examples : add passkey test (#3856)
* examples : add passkey test
* passkey : better prints
* passkey : select pass key pos from CLI
* passkey : simplify n_past logic
* make : add passkey target
* passkey : add "self-extend"-like context extension (#4810)
* llama : "self-extend"-like context extension
* passkey : add comment
* passkey : add readme
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
@@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift( cache.head = new_head != cache.size ? new_head : 0; } +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + // // model loading and saving // @@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { } void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); +} + // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. |