Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  51
1 file changed, 36 insertions, 15 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 109ff717..59a59d56 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -659,7 +659,11 @@ struct server_context {
     bool load_model(const gpt_params & params_) {
         params = params_;
 
+        // dedicate one sequence to the system prompt
+        params.n_parallel += 1;
+
         std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
             return false;
@@ -1018,8 +1022,8 @@ struct server_context {
             }
 
             // assign the system KV cache to all parallel sequences
-            for (int32_t i = 1; i < params.n_parallel; ++i) {
-                llama_kv_cache_seq_cp(ctx, 0, i, 0, system_tokens.size());
+            for (int32_t i = 1; i <= params.n_parallel; ++i) {
+                llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
             }
         }
 
@@ -1306,7 +1310,7 @@ struct server_context {
             const int n_embd = llama_n_embd(model);
 
             for (int i = 0; i < batch.n_tokens; ++i) {
-                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
+                if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
                     continue;
                 }
 
@@ -1633,8 +1637,8 @@ struct server_context {
                     {"n_cache_tokens", slot.cache_tokens.size()}
                 });
 
-                llama_kv_cache_seq_rm (ctx, slot.id, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -1666,7 +1670,7 @@ struct server_context {
 
             // TODO: we always have to take into account the "system_tokens"
             //       this is not great and needs to be improved somehow
-            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
+            llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
 
             slot.n_past += 1;
 
@@ -1804,9 +1808,6 @@ struct server_context {
                         // reuse any previously computed tokens that are common with the new prompt
                         slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
 
-                        // remove the non-common part from the cache
-                        slot.cache_tokens.resize(slot.n_past);
-
                         // push the prompt into the sampling context (do not apply grammar)
                         for (int i = 0; i < slot.n_past; ++i) {
                             llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
@@ -1837,8 +1838,28 @@ struct server_context {
                     }
                 }
 
-                const int p0 = (int) system_tokens.size() + slot.n_past;
-                llama_kv_cache_seq_rm(ctx, slot.id, p0, -1);
+                // keep only the common part
+                int p0 = (int) system_tokens.size() + slot.n_past;
+                if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                    // could not partially delete (likely using a non-Transformer model)
+                    llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+
+                    p0 = (int) system_tokens.size();
+                    if (p0 != 0) {
+                        // copy over the system prompt when there is one
+                        llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
+                    }
+
+                    // there is no common part left (except for the system prompt)
+                    slot.n_past = 0;
+                    slot.n_past_se = 0;
+                    slot.ga_i = 0;
+                    // TODO: is the system prompt ever in the sampling context?
+                    llama_sampling_reset(slot.ctx_sampling);
+                }
+
+                // remove the non-common part from the cache
+                slot.cache_tokens.resize(slot.n_past);
 
                 LOG_INFO("kv cache rm [p0, end)", {
                     { "id_slot", slot.id },
@@ -1863,7 +1884,7 @@ struct server_context {
                         }
                     }
 
-                    llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id }, false);
+                    llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -1937,9 +1958,9 @@ struct server_context {
                     LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                     LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                    llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i, slot.n_past_se, ib * bd);
-                    llama_kv_cache_seq_div(ctx, slot.id, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                    llama_kv_cache_seq_add(ctx, slot.id, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+                    llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                    llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                    llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
 
                     slot.n_past_se -= bd;
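In short, this patch reserves KV-cache sequence 0 for the shared system prompt and maps slot i to sequence i + 1 (hence the slot.id + 1 arguments throughout), and it uses the boolean result of llama_kv_cache_seq_rm to recover when a backend cannot erase a partial range. The following sketch illustrates that recovery path outside of server.cpp; it assumes the llama.h API at this revision, and restore_slot_cache is a hypothetical helper written for illustration, not a function from this patch.

    #include "llama.h"

    // Sketch only, not part of this patch. Assumes the slot-to-sequence
    // mapping introduced here: sequence 0 holds the system prompt and slot
    // `slot_id` writes to sequence `slot_id + 1`.
    //
    // Drops everything the slot has cached past the `n_common` prompt tokens
    // it shares with the new request. If the backend cannot erase a partial
    // range (llama_kv_cache_seq_rm returns false, e.g. a non-Transformer
    // cache), the whole sequence is wiped and the system prompt is re-shared
    // from sequence 0. Returns how many cached tokens remain usable.
    static int restore_slot_cache(llama_context * ctx, int slot_id,
                                  int n_system_tokens, int n_common) {
        const llama_seq_id seq = slot_id + 1;
        const llama_pos    p0  = n_system_tokens + n_common;

        if (llama_kv_cache_seq_rm(ctx, seq, p0, -1)) {
            return n_common; // partial erase worked, keep the common prefix
        }

        // full wipe of this slot's sequence ...
        llama_kv_cache_seq_rm(ctx, seq, -1, -1);

        // ... then share the system prompt cells from sequence 0 again
        if (n_system_tokens > 0) {
            llama_kv_cache_seq_cp(ctx, 0, seq, -1, -1);
        }

        return 0; // nothing left beyond the system prompt
    }

The same mapping explains the hunk at line 1022: load_model allocates one extra sequence, and the system prompt is then copied from sequence 0 into sequences 1 through params.n_parallel inclusive, one per slot.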