Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp | 51
1 file changed, 36 insertions(+), 15 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 109ff717..59a59d56 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -659,7 +659,11 @@ struct server_context {
     bool load_model(const gpt_params & params_) {
         params = params_;
 
+        // dedicate one sequence to the system prompt
+        params.n_parallel += 1;
+
         std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
             return false;
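
This hunk reserves a dedicated KV-cache sequence for the system prompt: one extra sequence is allocated at init time, then hidden from the rest of the server. Sequence 0 holds the system prompt and slot i decodes into sequence i + 1. A minimal sketch of the resulting layout (the helper name is illustrative, not part of the patch):

    #include "llama.h"

    // Sequence layout after this change (sketch):
    //   seq 0                       -> shared system prompt
    //   seq 1 .. params.n_parallel  -> one sequence per slot
    static llama_seq_id slot_seq_id(int slot_id) {
        return slot_id + 1; // slot ids stay 0-based; seq 0 is reserved
    }
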
@@ -1018,8 +1022,8 @@ struct server_context {
            }
 
            // assign the system KV cache to all parallel sequences
-            for (int32_t i = 1; i < params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            for (int32_t i = 1; i <= params.n_parallel; ++i) {
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
            }
        }
 
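Because of the extra sequence, the broadcast loop now runs over 1..n_parallel inclusive, and the copy range becomes (-1, -1), which selects the whole source sequence rather than just [0, system_tokens.size()). A hedged sketch of the broadcast on its own:

    #include "llama.h"

    // Sketch: share the system prompt's KV cells (all of sequence 0)
    // with every slot sequence; -1/-1 means the entire position range.
    static void broadcast_system_kv(llama_context * ctx, int32_t n_parallel) {
        for (int32_t i = 1; i <= n_parallel; ++i) {
            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
        }
    }
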
@@ -1306,7 +1310,7 @@ struct server_context {
        const int n_embd = llama_n_embd(model);
 
        for (int i = 0; i < batch.n_tokens; ++i) {
-            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
                continue;
            }
 
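The same slot.id + 1 offset now appears everywhere a slot's tokens are matched or touched; here it filters the embedding batch, and the llama_batch_add and llama_kv_cache_* calls below follow the same pattern. A small illustrative predicate (the helper is hypothetical, not in the patch):

    #include "llama.h"

    // Sketch: does the i-th batch entry carry output for this slot?
    // seq_id[i][0] is the sequence the token was decoded in, offset by
    // one because sequence 0 belongs to the system prompt.
    static bool is_slot_output(const llama_batch & batch, int i, int slot_id) {
        return batch.logits[i] != 0 && batch.seq_id[i][0] == slot_id + 1;
    }
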
@@ -1633,8 +1637,8 @@ struct server_context {
                {"n_cache_tokens", slot.cache_tokens.size()}
            });
 
-            llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+            llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+            llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
            if (slot.params.cache_prompt) {
                for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
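This is the context-shift path: positions [n_keep, n_keep + n_discard) are evicted and everything after them slides left by n_discard, all within the slot's own sequence. A worked sketch with made-up numbers, assuming the usual half-open [p0, p1) convention of these calls:

    #include "llama.h"

    // Sketch: with n_keep = 4, n_discard = 2 and positions 0..9 cached,
    // seq_rm drops [4, 6) and seq_add shifts [6, 10) by -2, giving 4..7.
    static void shift_context(llama_context * ctx, llama_seq_id seq,
                              int n_keep, int n_discard, int p_end) {
        llama_kv_cache_seq_rm (ctx, seq, n_keep            , n_keep + n_discard);
        llama_kv_cache_seq_add(ctx, seq, n_keep + n_discard, p_end, -n_discard);
    }

In the server, seq would be slot.id + 1 and p_end is system_tokens.size() + slot.n_past.
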
@@ -1666,7 +1670,7 @@ struct server_context {
 
                // TODO: we always have to take into account the "system_tokens"
                //       this is not great and needs to be improved somehow
-                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
 
                slot.n_past += 1;
 
@@ -1804,9 +1808,6 @@ struct server_context {
                    // reuse any previously computed tokens that are common with the new prompt
                    slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
-                    // remove the non-common part from the cache
-                    slot.cache_tokens.resize(slot.n_past);
-
                    // push the prompt into the sampling context (do not apply grammar)
                    for (int i = 0; i < slot.n_past; ++i) {
                        llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -1837,8 +1838,28 @@ struct server_context {
                    }
                }
 
-                const int p0 = (int) system_tokens.size() + slot.n_past;
-                llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
+                // keep only the common part
+                int p0 = (int) system_tokens.size() + slot.n_past;
+                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                    // could not partially delete (likely using a non-Transformer model)
+                    llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+
+                    p0 = (int) system_tokens.size();
+                    if (p0 != 0) {
+                        // copy over the system prompt when there is one
+                        llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+                    }
+
+                    // there is no common part left (except for the system prompt)
+                    slot.n_past = 0;
+                    slot.n_past_se = 0;
+                    slot.ga_i = 0;
+                    // TODO: is the system prompt ever in the sampling context?
+                    llama_sampling_reset(slot.ctx_sampling);
+                }
+
+                // remove the non-common part from the cache
+                slot.cache_tokens.resize(slot.n_past);
 
                LOG_INFO("kv cache rm [p0, end)", {
                    { "id_slot", slot.id },
@@ -1863,7 +1884,7 @@ struct server_context {
                    }
                }
 
-                llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
 
                if (slot.params.cache_prompt) {
                    slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -1937,9 +1958,9 @@ struct server_context {
                        LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                        LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
 
                        slot.n_past_se -= bd;
 
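
These are the Self-Extend (grouped attention) cache updates; only the target sequence changes. Each step shifts the region before the group window, divides the positions inside the window by ga_n, and shifts the tail back to close the gap. A sketch of one step, with bd and dd computed as I read them in the surrounding server code:

    #include "llama.h"

    // Sketch: one Self-Extend compression step on a slot's sequence.
    static void self_extend_step(llama_context * ctx, llama_seq_id seq,
                                 int ga_i, int ga_w, int ga_n, int ib, int n_past_se) {
        const int bd = (ga_w / ga_n) * (ga_n - 1);     // block displacement
        const int dd = (ga_w / ga_n) - ib * bd - ga_w; // tail displacement
        llama_kv_cache_seq_add(ctx, seq, ga_i, n_past_se, ib * bd);
        llama_kv_cache_seq_div(ctx, seq, ga_i + ib * bd, ga_i + ib * bd + ga_w, ga_n);
        llama_kv_cache_seq_add(ctx, seq, ga_i + ib * bd + ga_w, n_past_se + ib * bd, dd);
    }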