author | Georgi Gerganov <ggerganov@gmail.com> | 2023-09-28 19:04:36 +0300
---|---|---
committer | GitHub <noreply@github.com> | 2023-09-28 19:04:36 +0300
commit | ec893798b7a2a803466cc8f063051499ec3d96f7 |
tree | 6c0c68de076d3d8493135cf7d958e43eeda04fd8 | examples/speculative/speculative.cpp
parent | 45855b3f1c7bdd0320aa632334d0b3e8965c26c4 |
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding
* better solution for p0 computation
* fix rope
* simpler rope implementation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable parameters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now
ggml-ci
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
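
The headline change for downstream code in the list above is the move from `llama_eval()` to `llama_decode()` driven by a `llama_batch`, with token positions passed explicitly and a unified KV cache shared by all sequences. The following is a minimal sketch (not from the commit) of that flow, assuming the llama.h signatures that appear in the diff further down (`llama_batch_get_one(tokens, n, pos_0, seq_id)` and `llama_decode(ctx, batch, n_threads)`); the `generate_greedy` name, the `n_vocab`/`n_predict` parameters, and the argmax sampling are illustrative placeholders, not part of this commit.

```cpp
// Sketch of prompt evaluation + greedy generation with the llama_decode()/llama_batch API.
// Assumes the signatures used in the diff below; not a drop-in replacement for old call sites.
#include <algorithm>
#include <vector>
#include "llama.h"

void generate_greedy(llama_context * ctx, std::vector<llama_token> & inp,
                     int n_vocab, int n_predict, int n_threads) {
    // evaluate the whole prompt as one batch in sequence 0, starting at position 0
    llama_decode(ctx, llama_batch_get_one(inp.data(), (int) inp.size(), 0, 0), n_threads);

    int n_past = (int) inp.size();

    for (int i = 0; i < n_predict; ++i) {
        // logits of the last evaluated token; llama_get_logits_ith(ctx, i), added in this
        // commit, selects a specific row when a batch requests several logit outputs
        const float * logits = llama_get_logits(ctx);
        llama_token id = (llama_token) (std::max_element(logits, logits + n_vocab) - logits);

        // decode the new token at position n_past in sequence 0; the unified KV cache
        // grows in place, so no context swap is needed
        llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0), n_threads);
        ++n_past;
    }
}
```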
Diffstat (limited to 'examples/speculative/speculative.cpp')
-rw-r--r-- | examples/speculative/speculative.cpp | 19
1 file changed, 11 insertions, 8 deletions
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index aa904183..2445d78d 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;

     // load the target model
-    params.perplexity = true; // HACK: enable logits_all = true
+    params.logits_all = true;
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);

     // load the draft model
@@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt with both models
-    llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
-    llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
-    llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);

     const auto t_enc_end = ggml_time_us();

@@ -134,7 +134,7 @@ int main(int argc, char ** argv) {

     while (true) {
         // sample from the target model
-        const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+        llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);

         // remember which tokens were sampled - used for repetition penalties during sampling
         last_tokens.erase(last_tokens.begin());
@@ -172,7 +172,8 @@ int main(int argc, char ** argv) {
             LOG("out of drafted tokens\n");
         }

-        llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
+        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
+        llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
         ++n_past_dft;

         // heuristic for n_draft
@@ -256,7 +257,8 @@ int main(int argc, char ** argv) {
             }

             // evaluate the drafted token on the draft model
-            llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
+            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
+            llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;

             if (grammar_dft != NULL) {
@@ -265,7 +267,8 @@ int main(int argc, char ** argv) {
         }

         // evaluate the target model on the drafted tokens
-        llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
+        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
+        llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;

         // the first token is always proposed by the traget model before the speculation loop
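
In the hunks above, each `llama_eval()` call becomes a `llama_kv_cache_seq_rm()` followed by `llama_decode()`: before re-decoding from a given position, the draft or target context drops every cached entry of sequence 0 from that position to the end of the context, then decodes the new tokens there, so rejected speculative tokens are rolled back without a context swap. A hedged sketch of that pattern, using only the calls visible in the diff (the `decode_from` helper itself is illustrative, not part of the commit):

```cpp
// Sketch of the rollback-then-decode pattern used in the diff above.
// Assumes the signatures shown there: llama_kv_cache_seq_rm(ctx, seq_id, p0, p1)
// and llama_decode(ctx, batch, n_threads).
#include "llama.h"

static int decode_from(llama_context * ctx, llama_token * tokens, int n_tokens,
                       int n_past, int n_ctx, int n_threads) {
    // drop stale KV cache entries of sequence 0 in [n_past, n_ctx), i.e. anything
    // left over from previously drafted tokens that were not accepted
    llama_kv_cache_seq_rm(ctx, 0, n_past, n_ctx);

    // decode the new tokens starting at position n_past in sequence 0
    return llama_decode(ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads);
}
```

With such a helper, the draft-side call site in the diff would read `decode_from(ctx_dft, &id, 1, n_past_dft, n_ctx, params.n_threads)`, and the target-side sites follow the same shape.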