author      Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>   2023-10-29 11:31:40 -0600
committer   GitHub <noreply@github.com>                                 2023-10-29 11:31:40 -0600
commit      6e08281e588bbba1a5d180290a94a43f167f3a1a (patch)
tree        46add394417eb2b5929793ca879c793a478fd3f8
parent      2046eb4345e62c4575b3cdc0115a51db89f3fb70 (diff)
Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
* Extend llama_kv_cache_seq_rm to allow matching any sequence

* Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear

  Use llama_kv_cache_clear for cache clearing.

  Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality.
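For callers, the first substitution is mechanical: every llama_kv_cache_tokens_rm(ctx, -1, -1) call that wiped the whole cache becomes llama_kv_cache_clear(ctx). A minimal sketch of an updated call site follows; the reset_cache wrapper is hypothetical, and ctx is assumed to be a valid context created elsewhere.

#include "llama.h"

// Hypothetical wrapper: full KV cache reset, e.g. between benchmark
// repetitions or before reusing a context for a new prompt.
static void reset_cache(struct llama_context * ctx) {
    // Previously: llama_kv_cache_tokens_rm(ctx, -1, -1);
    llama_kv_cache_clear(ctx);
}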
-rw-r--r--  common/common.cpp                          2
-rw-r--r--  examples/batched-bench/batched-bench.cpp   2
-rw-r--r--  examples/llama-bench/llama-bench.cpp       4
-rw-r--r--  examples/main/main.cpp                     2
-rw-r--r--  examples/perplexity/perplexity.cpp         6
-rw-r--r--  examples/server/server.cpp                 2
-rw-r--r--  llama.cpp                                 29
-rw-r--r--  llama.h                                   15
8 files changed, 30 insertions, 32 deletions
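The second change is semantic: llama_kv_cache_seq_rm now treats a negative seq_id as "match any sequence", so positional deletions that previously went through llama_kv_cache_tokens_rm can be expressed with it directly. A sketch under the same assumptions; truncate_past is a hypothetical helper, and n_past would typically be the number of reusable session tokens, as in examples/main below.

#include "llama.h"

// Hypothetical helper: drop every cached cell at position >= n_past,
// regardless of which sequence it belongs to.
static void truncate_past(struct llama_context * ctx, llama_pos n_past) {
    // Previously: llama_kv_cache_tokens_rm(ctx, n_past, -1);
    // seq_id < 0 matches any sequence, p1 < 0 means "to the end of the cache".
    llama_kv_cache_seq_rm(ctx, /*seq_id=*/ -1, /*p0=*/ n_past, /*p1=*/ -1);
}

A removal with a non-negative seq_id keeps its previous behavior, e.g. llama_kv_cache_seq_rm(ctx, 0, p0, p1) still erases only cells that belong to sequence 0.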
diff --git a/common/common.cpp b/common/common.cpp
index f81f4d35..c187128d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
- llama_kv_cache_tokens_rm(lctx, -1, -1);
+ llama_kv_cache_clear(lctx);
llama_reset_timings(lctx);
}
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
index 43f9c971..533c55c1 100644
--- a/examples/batched-bench/batched-bench.cpp
+++ b/examples/batched-bench/batched-bench.cpp
@@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
const auto t_pp_start = ggml_time_us();
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
LOG_TEE("%s: llama_decode() failed\n", __func__);
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 20767d55..78039818 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {
test t(inst, lmodel, ctx);
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
// warmup run
if (t.n_prompt > 0) {
@@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) {
}
for (int i = 0; i < params.reps; i++) {
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
uint64_t t_start = get_time_ns();
if (t.n_prompt > 0) {
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 3d9f670b..8a43b6ab 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
}
// remove any "future" tokens that we might have inherited from the previous session
- llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
+ llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}
LOGLN(
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 3c2542e8..bd2c73d8 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
@@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}
// clear the KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
if (logits.empty()) {
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 5b7e4139..c163c7f8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -857,7 +857,7 @@ struct llama_server_context
void kv_cache_clear() {
// clear the entire KV cache
- llama_kv_cache_tokens_rm(ctx, -1, -1);
+ llama_kv_cache_clear(ctx);
clean_kv_cache = false;
}
diff --git a/llama.cpp b/llama.cpp
index d8510a5c..a4340d52 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
return 0;
}
-static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
- if (c0 < 0) c0 = 0;
- if (c1 < 0) c1 = cache.size;
-
- for (int32_t i = c0; i < c1; ++i) {
+static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
+ for (int32_t i = 0; i < cache.size; ++i) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
}
-
- // Searching for a free slot can start here since we know it will be empty.
- cache.head = uint32_t(c0);
+ cache.head = 0;
}
static void llama_kv_cache_seq_rm(
@@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.size; ++i) {
- if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
- cache.cells[i].seq_id.erase(seq_id);
+ if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
+ if (seq_id < 0) {
+ cache.cells[i].seq_id.clear();
+ } else if (cache.cells[i].has_seq_id(seq_id)) {
+ cache.cells[i].seq_id.erase(seq_id);
+ } else {
+ continue;
+ }
if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
@@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head;
}
-void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
- llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
+void llama_kv_cache_clear(struct llama_context * ctx) {
+ llama_kv_cache_clear(ctx->kv_self);
}
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
@@ -9654,7 +9655,7 @@ int llama_eval(
llama_token * tokens,
int32_t n_tokens,
int n_past) {
- llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+ llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
if (ret < 0) {
@@ -9669,7 +9670,7 @@ int llama_eval_embd(
float * embd,
int32_t n_tokens,
int n_past) {
- llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
+ llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
diff --git a/llama.h b/llama.h
index 6927bd60..d727dbd9 100644
--- a/llama.h
+++ b/llama.h
@@ -334,17 +334,14 @@ extern "C" {
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
- // Remove all tokens data of cells in [c0, c1)
- // c0 < 0 : [0, c1]
- // c1 < 0 : [c0, inf)
- LLAMA_API void llama_kv_cache_tokens_rm(
- struct llama_context * ctx,
- int32_t c0,
- int32_t c1);
+ // Clear the KV cache
+ LLAMA_API void llama_kv_cache_clear(
+ struct llama_context * ctx);
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
- // p0 < 0 : [0, p1]
- // p1 < 0 : [p0, inf)
+ // seq_id < 0 : match any sequence
+ // p0 < 0 : [0, p1]
+ // p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_cache_seq_rm(
struct llama_context * ctx,
llama_seq_id seq_id,