path: root/examples/save-load-state/save-load-state.cpp
author     Georgi Gerganov <ggerganov@gmail.com>    2023-09-28 19:04:36 +0300
committer  GitHub <noreply@github.com>              2023-09-28 19:04:36 +0300
commit     ec893798b7a2a803466cc8f063051499ec3d96f7 (patch)
tree       6c0c68de076d3d8493135cf7d958e43eeda04fd8 /examples/save-load-state/save-load-state.cpp
parent     45855b3f1c7bdd0320aa632334d0b3e8965c26c4 (diff)
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
  * ggml-cuda : update rope implementation for parallel decoding
  * better solution for p0 computation
  * fix rope
  * simpler rope implementation
  ---------
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable parameters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
  * ggml-cuda : add rope f16, restore performance
  * offload KQ_mask with all models
  * fix rope shift
  ---------
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now
ggml-ci
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
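The thread running through these changes, and through the diff below, is the replacement of llama_eval() with the new llama_decode()/llama_batch API. A minimal sketch of the migration, assuming the llama.h of this commit (llama_decode still takes an explicit n_threads argument here; later revisions folded it into the context parameters):

    // before: raw token pointer + explicit n_past
    //   llama_eval(ctx, tokens.data(), n_tokens, n_past, n_threads);

    // after: the same tokens wrapped in a llama_batch; the trailing
    // argument of llama_batch_get_one() is the sequence id (0 = the
    // single-stream case used by the examples)
    llama_batch batch = llama_batch_get_one(tokens.data(), n_tokens, n_past, 0);
    if (llama_decode(ctx, batch, n_threads) != 0) {
        fprintf(stderr, "%s : llama_decode failed\n", __func__);
        return 1;
    }

llama_batch_get_one() is only the single-sequence convenience wrapper; parallel callers (see the new parallel example) fill a llama_batch directly so several sequences share one decode call and one unified KV cache.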
Diffstat (limited to 'examples/save-load-state/save-load-state.cpp')
-rw-r--r--  examples/save-load-state/save-load-state.cpp | 16
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 95527bb8..6e4d40b9 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -35,11 +35,11 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
- auto model = llama_load_model_from_file(params.model.c_str(), lparams);
+ auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == nullptr) {
return 1;
}
- auto ctx = llama_new_context_with_model(model, lparams);
+ auto * ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
@@ -54,7 +54,7 @@ int main(int argc, char ** argv) {
}
// evaluate prompt
- llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
+ llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens;
@@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
printf("\n%s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
- auto logits = llama_get_logits(ctx);
+ auto * logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -91,7 +91,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
+ if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
@@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
// make new context
- auto ctx2 = llama_new_context_with_model(model, lparams);
+ auto * ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
@@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
- auto logits = llama_get_logits(ctx2);
+ auto * logits = llama_get_logits(ctx2);
auto n_vocab = llama_n_vocab(ctx2);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
- if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
+ if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);
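For context, this example exercises the state round-trip the commit has to keep working: generate with the first context, serialize it, restore into ctx2, and check that the second run continues identically. A rough sketch of that round-trip, assuming the state API of this era of llama.h (llama_get_state_size / llama_copy_state_data / llama_set_state_data), covering the state named in the comment above ("rng, logits, embedding and kv_cache"):

    // save: copy the serializable state of the first context
    const size_t state_size = llama_get_state_size(ctx);
    std::vector<uint8_t> state_mem(state_size);
    llama_copy_state_data(ctx, state_mem.data());
    // ... write state_mem to a file, then llama_free(ctx) ...

    // load: restore into a fresh context created from the same model
    llama_set_state_data(ctx2, state_mem.data());

Because the restored state includes the KV cache, the second generation loop can resume from the same n_past without re-evaluating the prompt.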