diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-09-28 19:04:36 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-28 19:04:36 +0300 |
commit | ec893798b7a2a803466cc8f063051499ec3d96f7 (patch) | |
tree | 6c0c68de076d3d8493135cf7d958e43eeda04fd8 /examples/perplexity | |
parent | 45855b3f1c7bdd0320aa632334d0b3e8965c26c4 (diff) |
llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive"
* llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask)
* ggml : ggml_rope now takes a vector with positions instead of n_past
* metal : add rope_f16 kernel + optimize cpy kernels
* llama : unified KV cache + batch inference API
* llama : add new llama_decode() API that works with llama_batch
* llama : add cell_max heuristic for more efficient kv_cache
* llama : extend llama_kv_cache API
* llama : more robust cell_max heuristic + wip shift
* metal : disable concurrency optimization
* llama : add llama_kv_cache_shift_seq + no more context swaps
* llama : apply K-cache roping for Falcon and Baichuan
* speculative : fix KV cache management
* parallel : example for serving multiple users in parallel
* parallel : disable hot-plug to avoid cache fragmentation
* fixes : speculative KV cache + llama worst-case graph
* llama : extend batch API to select which logits to output
* llama : fix worst case graph build
* ggml-cuda : update rope implementation for parallel decoding (#3254)
* ggml-cuda : update rope implementation for parallel decoding
* better solution for p0 computation
* fix rope
* simpler rope implementation
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* make : add parallel to build + fix static functions in llama.cpp
* simple : fix token counting
* parallel : various improvements
* llama : fix cell_max logic + rename functions
* parallel : try smaller batches when the KV cache is fragmented
* parallel : fix sequence termination criteria
* llama : silence errors KV cache errors
* parallel : remove new line from prompt
* parallel : process system prompt once + configurable paramters + llama API
* parallel : remove question with short answers
* parallel : count cache misses
* parallel : print misses on each request
* parallel : minor
* llama : fix n_kv to never become 0
* parallel : rename hot-plug to continuous-batching
* llama : improve llama_batch API + simplify parallel example
* simple : add parallel decoding support
* simple : improve comments + free batch
* ggml-cuda : add rope f16, restore performance with parallel decoding (#3272)
* ggml-cuda : add rope f16, restore performance
* offload KQ_mask with all models
* fix rope shift
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* llama : disable MPI for now
ggml-ci
* train : make KQ_pos memory buffer permanent via dummy scale op
* ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275)
ggml-ci
* parallel : fix bug (extra BOS) + smaller token_prev array
* parallel : fix cases where the input prompts can overflow the batch
* parallel : add disabled experimental batch chunking in powers of two
* llama : llama.h formatting + comments
* simple : add README.md
* llama : fix kv cache heuristic when context is less than 32
* parallel : fix crash when `-n -1`
* llama : simplify returns if/else branches
* metal : use mm kernels for batch size > 2
* examples : utilize new llama_get_logits_ith()
* examples : add example for batched decoding
* examples : do not eval prompt 2 times (close #3348)
* server : clear the KV cache beyond n_past before llama_decode
* server : avoid context swaps by shifting the KV cache
---------
Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'examples/perplexity')
-rw-r--r-- | examples/perplexity/perplexity.cpp | 51 |
1 files changed, 36 insertions, 15 deletions
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 2b375e34..de08bd4a 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -80,7 +80,9 @@ static void write_logfile( static std::vector<float> softmax(const std::vector<float>& logits) { std::vector<float> probs(logits.size()); float max_logit = logits[0]; - for (float v : logits) max_logit = std::max(max_logit, v); + for (float v : logits) { + max_logit = std::max(max_logit, v); + } double sum_exp = 0.0; for (size_t i = 0; i < logits.size(); i++) { // Subtract the maximum logit value from the current logit value for numerical stability @@ -89,15 +91,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) { sum_exp += exp_logit; probs[i] = exp_logit; } - for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp; + for (size_t i = 0; i < probs.size(); i++) { + probs[i] /= sum_exp; + } return probs; } static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { float max_logit = logits[0]; - for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]); + for (int i = 1; i < n_vocab; ++i) { + max_logit = std::max(max_logit, logits[i]); + } double sum_exp = 0.0; - for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit); + for (int i = 0; i < n_vocab; ++i) { + sum_exp += expf(logits[i] - max_logit); + } return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp}; } @@ -108,7 +116,8 @@ static void process_logits( std::mutex mutex; int counter = 0; auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () { - double local_nll = 0, local_nll2 = 0; + double local_nll = 0; + double local_nll2 = 0; while (true) { std::unique_lock<std::mutex> lock(mutex); int i = counter++; @@ -126,10 +135,13 @@ static void process_logits( prob_history[i] = results.prob; } }; - for (auto & w : workers) w = std::thread(compute); + for (auto & w : workers) { + w = std::thread(compute); + } compute(); - for (auto & w : workers) w.join(); - + for (auto & w : workers) { + w.join(); + } } static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) { @@ -152,8 +164,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {std::move(tokens), 0., {}, {}}; } - std::vector<float> logit_history; - std::vector<float> prob_history; + std::vector<float> logit_history; + std::vector<float> prob_history; logit_history.resize(tokens.size()); prob_history.resize(tokens.size()); @@ -195,12 +207,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -320,6 +335,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); @@ -332,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par tokens[batch_start] = llama_token_bos(ctx); } - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -402,7 +420,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par } static std::vector<float> hellaswag_evaluate_tokens( - llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread + llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread ) { std::vector<float> result; result.reserve(tokens.size() * n_vocab); @@ -410,7 +428,7 @@ static std::vector<float> hellaswag_evaluate_tokens( for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { size_t n_tokens = tokens.size() - i_chunk * n_batch; n_tokens = std::min(n_tokens, size_t(n_batch)); - if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {}; } @@ -550,6 +568,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { query_embd.resize(32); } + // clear the KV cache + llama_kv_cache_tokens_rm(ctx, -1, -1); + auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads); if (logits.empty()) { fprintf(stderr, "%s : failed to eval\n", __func__); @@ -661,7 +682,7 @@ int main(int argc, char ** argv) { return 1; } - params.perplexity = true; + params.logits_all = true; params.n_batch = std::min(params.n_batch, params.n_ctx); if (params.ppl_stride > 0) { |