author    | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-11 21:25:33 +0300
committer | GitHub <noreply@github.com>           | 2023-10-11 21:25:33 +0300
commit    | 8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (patch)
tree      | 50946ed36e647f5619b1866d77e3042d4d9743c5 /examples/batched-bench/batched-bench.cpp
parent    | 24ba3d829e31a6eda3fa1723f692608c2fa3adda (diff)
batched : add bench tool (#3545)
* batched : add bench tool
* batched : minor fix table
* batched-bench : add readme + n_kv_max is now configurable
* batched-bench : init warm-up batch
* batched-bench : pass custom set of PP, TG and PL
* batched-bench : add mmq CLI arg
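
The custom PP/TG/PL lists and the N_KV_MAX, NGL and MMQ arguments mentioned above are all positional. A sketch of an invocation, based on the usage string in the new file (the binary name and model filename are only placeholders for whatever your build and model are called):

    ./batched-bench ggml-model-f16.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32

Here 2048 is the KV cache size (n_kv_max), the first 0 means the prompt is not shared across the parallel sequences, 999 requests offloading of all layers to the GPU, the second 0 sets mul_mat_q to false, and the three comma-separated lists are the prompt-processing (PP), text-generation (TG) and parallel-sequence (PL) values to sweep.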
Diffstat (limited to 'examples/batched-bench/batched-bench.cpp')
-rw-r--r-- | examples/batched-bench/batched-bench.cpp | 251
1 file changed, 251 insertions, 0 deletions
diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp
new file mode 100644
index 00000000..3e1e0716
--- /dev/null
+++ b/examples/batched-bench/batched-bench.cpp
@@ -0,0 +1,251 @@
+#include "common.h"
+#include "llama.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include <string>
+#include <vector>
+
+// mutates the input string
+static std::vector<int> parse_list(char * p) {
+    std::vector<int> ret;
+
+    char * q = p;
+
+    while (*p) {
+        if (*p == ',') {
+            *p = '\0';
+            ret.push_back(std::atoi(q));
+            q = p + 1;
+        }
+
+        ++p;
+    }
+
+    ret.push_back(std::atoi(q));
+
+    return ret;
+}
+
+int main(int argc, char ** argv) {
+    gpt_params params;
+
+    if (argc == 1 || argv[1][0] == '-') {
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [IS_PP_SHARED] [NGL] [MMQ] <PP> <TG> <PL>\n", argv[0]);
+        printf("  <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
+        printf("  example: %s ggml-model-f16.gguf 2048 0 999 0 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
+        return 1;
+    }
+
+    int n_kv_max     = 2048;
+    int is_pp_shared = 0;
+    int n_gpu_layers = 0;
+    int mmq          = 0;
+
+    std::vector<int> n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, };
+    std::vector<int> n_tg = { 128, 256, };
+    std::vector<int> n_pl = { 1, 2, 4, 8, 16, 32, };
+    //std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };
+
+    if (argc >= 2) {
+        params.model = argv[1];
+    }
+
+    if (argc >= 3) {
+        n_kv_max = std::atoi(argv[2]);
+    }
+
+    if (argc >= 4) {
+        is_pp_shared = std::atoi(argv[3]);
+    }
+
+    if (argc >= 5) {
+        n_gpu_layers = std::atoi(argv[4]);
+    }
+
+    if (argc >= 6) {
+        mmq = std::atoi(argv[5]);
+    }
+
+    if (argc >= 7) {
+        n_pp = parse_list(argv[6]);
+    }
+
+    if (argc >= 8) {
+        n_tg = parse_list(argv[7]);
+    }
+
+    if (argc >= 9) {
+        n_pl = parse_list(argv[8]);
+    }
+
+    // init LLM
+
+    llama_backend_init(params.numa);
+
+    // initialize the model
+
+    llama_model_params model_params = llama_model_default_params();
+
+    model_params.n_gpu_layers = n_gpu_layers;
+
+    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+    if (model == NULL) {
+        fprintf(stderr, "%s: error: unable to load model\n", __func__);
+        return 1;
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.seed      = 1234;
+    ctx_params.n_ctx     = n_kv_max;
+    ctx_params.n_batch   = 512;
+    ctx_params.mul_mat_q = mmq;
+
+    ctx_params.n_threads       = params.n_threads;
+    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx == NULL) {
+        fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
+        return 1;
+    }
+
+    llama_batch batch = llama_batch_init(n_kv_max, 0);
+
+    // decode in batches of ctx_params.n_batch tokens
+    auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token  + i,
+                nullptr,
+                batch.pos    + i,
+                batch.seq_id + i,
+                batch.logits + i,
+                0, 0, 0, // unused
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
+            }
+        }
+
+        return true;
+    };
+
+    // warm up
+    {
+        batch.n_tokens = 16;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            batch.token[i]  = 0;
+            batch.pos[i]    = i;
+            batch.seq_id[i] = 0;
+            batch.logits[i] = false;
+        }
+
+        if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+            LOG_TEE("%s: llama_decode() failed\n", __func__);
+            return 1;
+        }
+    }
+
+    LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
+    LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
+
+    for (int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
+        for (int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
+            for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
+                const int pp = n_pp[i_pp];
+                const int tg = n_tg[i_tg];
+                const int pl = n_pl[i_pl];
+
+                const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+
+                if (n_ctx_req > n_kv_max) {
+                    continue;
+                }
+
+                batch.n_tokens = is_pp_shared ? pp : pl*pp;
+
+                for (int i = 0; i < batch.n_tokens; ++i) {
+                    batch.token[i]  = 0;
+                    batch.pos[i]    = i;
+                    batch.seq_id[i] = 0;
+                    batch.logits[i] = false;
+                }
+                batch.logits[batch.n_tokens - 1] = true;
+
+                const auto t_pp_start = ggml_time_us();
+
+                llama_kv_cache_tokens_rm(ctx, -1, -1);
+
+                if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                    LOG_TEE("%s: llama_decode() failed\n", __func__);
+                    return 1;
+                }
+
+                if (is_pp_shared) {
+                    for (int32_t i = 1; i < pl; ++i) {
+                        llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
+                    }
+                }
+
+                const auto t_pp_end = ggml_time_us();
+
+                const auto t_tg_start = ggml_time_us();
+
+                for (int i = 0; i < tg; ++i) {
+                    batch.n_tokens = pl;
+
+                    for (int j = 0; j < pl; ++j) {
+                        batch.token[j]  = 0;
+                        batch.pos[j]    = pp + i;
+                        batch.seq_id[j] = j;
+                        batch.logits[j] = true;
+                    }
+
+                    if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                        LOG_TEE("%s: llama_decode() failed\n", __func__);
+                        return 1;
+                    }
+                }
+
+                const auto t_tg_end = ggml_time_us();
+
+                const int32_t n_kv = n_ctx_req;
+
+                const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
+                const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
+                const float t    = t_pp + t_tg;
+
+                const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
+                const float speed_tg = pl*tg / t_tg;
+                const float speed    = n_kv / t;
+
+                LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
+            }
+        }
+    }
+
+    llama_print_timings(ctx);
+
+    llama_batch_free(batch);
+
+    llama_free(ctx);
+    llama_free_model(model);
+
+    llama_backend_free();
+
+    fprintf(stderr, "\n\n");
+
+    return 0;
+}
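
For reading the result table the tool prints: each row is one (PP, TG, PL) combination (the B column is simply PL), and the columns follow directly from the formulas in the code above. For example, with PP = 512, TG = 128, PL = 4 and is_pp_shared = 0:

    N_KV = pl*(pp + tg) = 4*(512 + 128) = 2560   (the combination is skipped if this exceeds n_kv_max)
    S_PP = pl*pp / t_pp   prompt tokens per second (pp / t_pp when the prompt is shared)
    S_TG = pl*tg / t_tg   generated tokens per second
    S    = n_kv / (t_pp + t_tg)   overall tokens per second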