summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-09-28 19:04:36 +0300
committerGitHub <noreply@github.com>2023-09-28 19:04:36 +0300
commitec893798b7a2a803466cc8f063051499ec3d96f7 (patch)
tree6c0c68de076d3d8493135cf7d958e43eeda04fd8 /examples/server/server.cpp
parent45855b3f1c7bdd0320aa632334d0b3e8965c26c4 (diff)
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp35
1 files changed, 24 insertions, 11 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ebd7f2fc..273eb36f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -381,6 +381,10 @@ struct llama_server_context
// compare the evaluated prompt with the new prompt
n_past = common_part(embd, prompt_tokens);
+
+ // since #3228 we now have to manually manage the KV cache
+ llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+
embd = prompt_tokens;
if (n_past == num_prompt_tokens)
{
@@ -411,19 +415,27 @@ struct llama_server_context
if (embd.size() >= (size_t)params.n_ctx)
{
- // Reset context
- const int n_left = (params.n_ctx - params.n_keep) / 2;
+ // Shift context
+
+ const int n_left = n_past - params.n_keep - 1;
+ const int n_discard = n_left/2;
+
+ llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
+ llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+
+ for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+ {
+ embd[i - n_discard] = embd[i];
+ }
+ embd.resize(embd.size() - n_discard);
+
+ n_past -= n_discard;
- std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
- new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
- embd = new_tokens;
- n_past = params.n_keep;
truncated = true;
LOG_VERBOSE("input truncated", {
{"n_ctx", params.n_ctx},
{"n_keep", params.n_keep},
{"n_left", n_left},
- {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
});
}
@@ -434,7 +446,8 @@ struct llama_server_context
{
n_eval = params.n_batch;
}
- if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
+
+ if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},
@@ -523,13 +536,13 @@ struct llama_server_context
{
static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
}
else if (mirostat == 2)
{
static float mirostat_mu = 2.0f * mirostat_tau;
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
}
else
@@ -540,7 +553,7 @@ struct llama_server_context
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
- llama_sample_temperature(ctx, &candidates_p, temp);
+ llama_sample_temp(ctx, &candidates_p, temp);
result.tok = llama_sample_token(ctx, &candidates_p);
}
}