Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  35 ++++++++++++++++++++++++-----------
1 file changed, 24 insertions(+), 11 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ebd7f2fc..273eb36f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -381,6 +381,10 @@ struct llama_server_context
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
+
+ // since #3228 we now have to manually manage the KV cache
+ llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
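Note on this first hunk: before #3228, the eval call managed the KV cache implicitly; after it, the caller owns the cache, so any entries past the shared prompt prefix must be evicted before re-evaluating. A minimal sketch of that pattern, assuming the llama.h API of this revision (the helper name and the prefix loop standing in for common_part() are illustrative):

#include <vector>
#include "llama.h"

// Illustrative helper: keep the KV entries shared with the new prompt and
// evict everything after them, so stale tokens are never attended to.
static int reuse_common_prefix(llama_context * ctx,
                               const std::vector<llama_token> & cached,
                               const std::vector<llama_token> & prompt,
                               int n_ctx) {
    int n_past = 0;
    while (n_past < (int) cached.size() && n_past < (int) prompt.size() &&
           cached[n_past] == prompt[n_past]) {
        n_past++; // same role as common_part() in the server code
    }
    // drop cache entries for sequence 0 in positions [n_past, n_ctx)
    llama_kv_cache_seq_rm(ctx, 0, n_past, n_ctx);
    return n_past;
}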
@@ -411,19 +415,27 @@ struct llama_server_context
if (embd.size() >= (size_t)params.n_ctx)
{
- // Reset context
- const int n_left = (params.n_ctx - params.n_keep) / 2;
+ // Shift context
+
+ const int n_left = n_past - params.n_keep - 1;
+ const int n_discard = n_left/2;
+
+ llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+ llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+ for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+ {
+ embd[i - n_discard] = embd[i];
+ }
+ embd.resize(embd.size() - n_discard);
+
+ n_past -= n_discard;
- std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
- new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
- embd = new_tokens;
- n_past = params.n_keep;
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
- {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
}
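This hunk replaces the old "reset and re-evaluate" behavior with an in-place shift: the first n_keep + 1 positions are pinned, the oldest half of what remains is discarded, and the survivors slide left in the cache. A sketch of the same arithmetic, assuming this revision's llama.h (the helper name is illustrative; the erase() call is equivalent to the copy loop plus resize() in the patch):

#include <vector>
#include "llama.h"

// Illustrative numbers: with n_past = 100 and n_keep = 10, n_left = 89 and
// n_discard = 44; cache positions [11, 55) are removed and [55, 100) shift
// left by 44, so generation continues without a full re-evaluation.
static void shift_context(llama_context * ctx, std::vector<llama_token> & embd,
                          int n_keep, int & n_past) {
    const int n_left    = n_past - n_keep - 1;
    const int n_discard = n_left / 2;

    llama_kv_cache_seq_rm   (ctx, 0, n_keep + 1, n_keep + n_discard + 1);
    llama_kv_cache_seq_shift(ctx, 0, n_keep + 1 + n_discard, n_past, -n_discard);

    // keep the token buffer in sync with the shifted cache
    embd.erase(embd.begin() + n_keep + 1, embd.begin() + n_keep + 1 + n_discard);
    n_past -= n_discard;
}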
@@ -434,7 +446,8 @@ struct llama_server_context
{
n_eval = params.n_batch;
}
- if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+
+ if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
@@ -523,13 +536,13 @@ struct llama_server_context
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
@@ -540,7 +553,7 @@ struct llama_server_context
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}
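The last two hunks are a mechanical rename: llama_sample_temperature() became llama_sample_temp() with the same arguments and behavior. For reference, the default (non-mirostat) path of the final hunk then applies the samplers in this order, each filtering candidates_p in place before one token is drawn (candidates_p is assumed to be initialized from the current logits, as in the surrounding code):

llama_sample_tail_free(ctx, &candidates_p, tfs_z,     min_keep);
llama_sample_typical  (ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p    (ctx, &candidates_p, top_p,     min_keep);
llama_sample_temp     (ctx, &candidates_p, temp);
result.tok = llama_sample_token(ctx, &candidates_p);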