diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-09-28 19:04:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 19:04:36 +0300 |
commit | ec893798b7a2a803466cc8f063051499ec3d96f7 (patch) | |
tree | 6c0c68de076d3d8493135cf7d958e43eeda04fd8 /examples/save-load-state/save-load-state.cpp | |
parent | 45855b3f1c7bdd0320aa632334d0b3e8965c26c4 (diff) |
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding
* better solution for p0 computation
* fix rope
* simpler rope implementation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence errors KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable paramters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now
ggml-ci
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r-- | examples/save-load-state/save-load-state.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 95527bb8..6e4d40b9 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -35,11 +35,11 @@ int main(int argc, char ** argv) { auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0); // init - auto model = llama_load_model_from_file(params.model.c_str(), lparams); + auto * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == nullptr) { return 1; } - auto ctx = llama_new_context_with_model(model, lparams); + auto * ctx = llama_new_context_with_model(model, lparams); if (ctx == nullptr) { llama_free_model(model); return 1; @@ -54,7 +54,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { printf("\n%s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx); + auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); @@ -91,7 +91,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -106,7 +106,7 @@ int main(int argc, char ** argv) { llama_free(ctx); // make new context - auto ctx2 = llama_new_context_with_model(model, lparams); + auto * ctx2 = llama_new_context_with_model(model, lparams); // Load state (rng, logits, embedding and kv_cache) from file { @@ -138,7 +138,7 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx2); + auto * logits = llama_get_logits(ctx2); auto n_vocab = llama_n_vocab(ctx2); std::vector<llama_token_data> candidates; candidates.reserve(n_vocab); @@ -151,7 +151,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); |