From ec893798b7a2a803466cc8f063051499ec3d96f7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Sep 2023 19:04:36 +0300 Subject: llama : custom attention mask + parallel decoding + no context swaps (#3228) * tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren --- examples/server/server.cpp | 35 ++++++++++++++++++++++++----------- 1 file changed, 24 insertions(+), 11 deletions(-) (limited to 'examples/server/server.cpp') diff --git a/examples/server/server.cpp b/examples/server/server.cpp index ebd7f2fc..273eb36f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -381,6 +381,10 @@ struct llama_server_context // compare the evaluated prompt with the new prompt n_past = common_part(embd, prompt_tokens); + + // since #3228 we now have to manually manage the KV cache + llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx); + embd = prompt_tokens; if (n_past == num_prompt_tokens) { @@ -411,19 +415,27 @@ struct llama_server_context if (embd.size() >= (size_t)params.n_ctx) { - // Reset context - const int n_left = (params.n_ctx - params.n_keep) / 2; + // Shift context + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++) + { + embd[i - n_discard] = embd[i]; + } + embd.resize(embd.size() - n_discard); + + n_past -= n_discard; - std::vector new_tokens(embd.begin(), embd.begin() + params.n_keep); - new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end()); - embd = new_tokens; - n_past = params.n_keep; truncated = true; LOG_VERBOSE("input truncated", { {"n_ctx", params.n_ctx}, {"n_keep", params.n_keep}, {"n_left", n_left}, - {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, }); } @@ -434,7 +446,8 @@ struct llama_server_context { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) + + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, @@ -523,13 +536,13 @@ struct llama_server_context { static float mirostat_mu = 2.0f * mirostat_tau; const int mirostat_m = 100; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); } else if (mirostat == 2) { static float mirostat_mu = 2.0f * mirostat_tau; - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); } else @@ -540,7 +553,7 @@ struct llama_server_context llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep); llama_sample_typical(ctx, &candidates_p, typical_p, min_keep); llama_sample_top_p(ctx, &candidates_p, top_p, min_keep); - llama_sample_temperature(ctx, &candidates_p, temp); + llama_sample_temp(ctx, &candidates_p, temp); result.tok = llama_sample_token(ctx, &candidates_p); } } -- cgit v1.2.3