author | Georgi Gerganov <ggerganov@gmail.com> | 2023-09-28 19:04:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 19:04:36 +0300 |
commit | ec893798b7a2a803466cc8f063051499ec3d96f7 (patch) | |
tree | 6c0c68de076d3d8493135cf7d958e43eeda04fd8 /examples/simple/simple.cpp | |
parent | 45855b3f1c7bdd0320aa632334d0b3e8965c26c4 (diff) | |
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding
* better solution for p0 computation
* fix rope
* simpler rope implementation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable parameters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now
ggml-ci
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
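The heart of this change, as exercised by the examples/simple/simple.cpp diff below, is the new llama_batch / llama_decode path that replaces llama_eval. The following is a minimal sketch of the single-sequence flow distilled from that diff: model and context setup are omitted, the 512-token batch size is copied from the example, the helper name decode_prompt is invented for illustration, and the signatures shown (llama_decode still taking an explicit thread count) are the ones at this commit and may differ in later revisions.

```cpp
#include "llama.h"

#include <vector>

// Sketch: decode a pre-tokenized prompt with the llama_batch API introduced by this commit.
static bool decode_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, int n_threads) {
    // batch with room for up to 512 tokens and no embedding inputs
    llama_batch batch = llama_batch_init(512, 0);

    batch.n_tokens = (int32_t) tokens.size();

    for (int32_t i = 0; i < batch.n_tokens; i++) {
        batch.token[i]  = tokens[i]; // token id
        batch.pos[i]    = i;         // position within the sequence
        batch.seq_id[i] = 0;         // single sequence -> sequence id 0
        batch.logits[i] = false;     // no logits needed for prompt tokens ...
    }

    // ... except for the last one, which we will sample from
    batch.logits[batch.n_tokens - 1] = true;

    const bool ok = llama_decode(ctx, batch, n_threads) == 0;

    if (ok) {
        // logits of the last prompt token, indexed by its position in the batch
        const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        (void) logits; // greedy sampling over these logits would go here
    }

    llama_batch_free(batch);

    return ok;
}
```

The batch.logits flags let the caller mark which positions it actually intends to sample from, and llama_get_logits_ith() then retrieves the row for a given batch position.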
Diffstat (limited to 'examples/simple/simple.cpp')
-rw-r--r-- | examples/simple/simple.cpp | 136 |
1 file changed, 97 insertions, 39 deletions
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index 440d22ec..1616a4a7 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -26,12 +26,18 @@ int main(int argc, char ** argv) {
         params.prompt = "Hello my name is";
     }

+    // total length of the sequence including the prompt
+    const int n_len = 32;
+
     // init LLM

     llama_backend_init(params.numa);

     llama_context_params ctx_params = llama_context_default_params();

+    ctx_params.seed  = 1234;
+    ctx_params.n_ctx = 2048;
+
     llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);

     if (model == NULL) {
@@ -41,20 +47,31 @@ int main(int argc, char ** argv) {

     llama_context * ctx = llama_new_context_with_model(model, ctx_params);

+    if (ctx == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return 1;
+    }
+
     // tokenize the prompt

     std::vector<llama_token> tokens_list;
     tokens_list = ::llama_tokenize(ctx, params.prompt, true);

-    const int max_context_size     = llama_n_ctx(ctx);
-    const int max_tokens_list_size = max_context_size - 4;
+    const int n_ctx    = llama_n_ctx(ctx);
+    const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
+
+    LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);

-    if ((int) tokens_list.size() > max_tokens_list_size) {
-        fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
+    // make sure the KV cache is big enough to hold all the prompt and generated tokens
+    if (n_kv_req > n_ctx) {
+        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
         return 1;
     }

-    fprintf(stderr, "\n\n");
+    // print the prompt token-by-token
+
+    fprintf(stderr, "\n");

     for (auto id : tokens_list) {
         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
@@ -62,63 +79,104 @@ int main(int argc, char ** argv) {

     fflush(stderr);

-    // main loop
+    // create a llama_batch with size 512
+    // we use this object to submit token data for decoding

-    // The LLM keeps a contextual cache memory of previous token evaluation.
-    // Usually, once this cache is full, it is required to recompute a compressed context based on previous
-    // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
-    // example, we will just stop the loop once this cache is full or once an end of stream is detected.
+    llama_batch batch = llama_batch_init(512, 0);

-    const int n_gen = std::min(32, max_context_size);
+    // evaluate the initial prompt
+    batch.n_tokens = tokens_list.size();

-    while (llama_get_kv_cache_token_count(ctx) < n_gen) {
-        // evaluate the transformer
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        batch.token[i]  = tokens_list[i];
+        batch.pos[i]    = i;
+        batch.seq_id[i] = 0;
+        batch.logits[i] = false;
+    }

-        if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return 1;
-        }
+    // llama_decode will output logits only for the last token of the prompt
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch, params.n_threads) != 0) {
+        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        return 1;
+    }
+
+    // main loop

-        tokens_list.clear();
+    int n_cur    = batch.n_tokens;
+    int n_decode = 0;

+    const auto t_main_start = ggml_time_us();
+
+    while (n_cur <= n_len) {
         // sample the next token
+        {
+            auto   n_vocab = llama_n_vocab(ctx);
+            auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);

-        llama_token new_token_id = 0;
+            std::vector<llama_token_data> candidates;
+            candidates.reserve(n_vocab);

-        auto logits  = llama_get_logits(ctx);
-        auto n_vocab = llama_n_vocab(ctx);
+            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+            }

-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
+            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

-        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-        }
+            // sample the most likely token
+            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+            // is it an end of stream?
+            if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
+                LOG_TEE("\n");
+
+                break;
+            }

-        llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            fflush(stdout);

-        new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
+            // prepare the next batch
+            batch.n_tokens = 0;

-        // is it an end of stream ?
-        if (new_token_id == llama_token_eos(ctx)) {
-            fprintf(stderr, " [end of text]\n");
-            break;
+            // push this new token for next evaluation
+            batch.token [batch.n_tokens] = new_token_id;
+            batch.pos   [batch.n_tokens] = n_cur;
+            batch.seq_id[batch.n_tokens] = 0;
+            batch.logits[batch.n_tokens] = true;
+
+            batch.n_tokens += 1;
+
+            n_decode += 1;
         }

-        // print the new token :
-        printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
-        fflush(stdout);
+        n_cur += 1;

-        // push this new token for next evaluation
-        tokens_list.push_back(new_token_id);
+        // evaluate the current batch with the transformer model
+        if (llama_decode(ctx, batch, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
     }

+    LOG_TEE("\n");
+
+    const auto t_main_end = ggml_time_us();
+
+    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+    llama_print_timings(ctx);
+
+    fprintf(stderr, "\n");
+
+    llama_batch_free(batch);
+
     llama_free(ctx);
     llama_free_model(model);

     llama_backend_free();

-    fprintf(stderr, "\n\n");
-
     return 0;
 }
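The same llama_batch drives the headline feature of this commit, parallel decoding: tokens from several independent sequences can be submitted in a single llama_decode call by tagging each token with its own seq_id, and the unified KV cache keeps the sequences apart. Below is a hedged sketch of that idea using only the calls visible in the diff above; the helper names decode_two_prompts and add_seq are invented for illustration, the real implementations are the examples/parallel and examples/batched programs touched by this commit, and details such as how llama_get_logits_ith indexes rows may have changed in later versions.

```cpp
#include "llama.h"

#include <vector>

// Sketch: decode the prompts of two independent sequences in a single llama_decode call.
// Assumes tokens_a and tokens_b are already tokenized and fit together in the 512-token batch.
static bool decode_two_prompts(llama_context * ctx,
                               const std::vector<llama_token> & tokens_a,
                               const std::vector<llama_token> & tokens_b,
                               int n_threads) {
    llama_batch batch = llama_batch_init(512, 0);

    batch.n_tokens = 0;

    // append one sequence to the shared batch
    auto add_seq = [&](const std::vector<llama_token> & tokens, int32_t seq_id) {
        for (size_t i = 0; i < tokens.size(); i++) {
            batch.token [batch.n_tokens] = tokens[i];
            batch.pos   [batch.n_tokens] = (int32_t) i;              // positions restart per sequence
            batch.seq_id[batch.n_tokens] = seq_id;                   // KV cache cells are tagged with this id
            batch.logits[batch.n_tokens] = (i == tokens.size() - 1); // logits only for the last token
            batch.n_tokens += 1;
        }
    };

    add_seq(tokens_a, 0);
    add_seq(tokens_b, 1);

    const bool ok = llama_decode(ctx, batch, n_threads) == 0;

    if (ok) {
        // logits for the last token of each sequence, indexed by batch position
        const float * logits_a = llama_get_logits_ith(ctx, (int32_t) tokens_a.size() - 1);
        const float * logits_b = llama_get_logits_ith(ctx, batch.n_tokens - 1);
        (void) logits_a;
        (void) logits_b; // per-sequence sampling would continue from here
    }

    llama_batch_free(batch);

    return ok;
}
```

Because each KV cache cell remembers its sequence id and position, sequences can later be shifted or removed individually (llama_kv_cache_shift_seq in the commit message above), which is what removes the need for context swaps.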