summaryrefslogtreecommitdiff
path: root/examples/llama-bench/llama-bench.cpp
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2024-03-13 18:54:21 +0100
committerGitHub <noreply@github.com>2024-03-13 18:54:21 +0100
commitf30ea47a87ed4446ad55adb265755dc9102956a2 (patch)
treefc885962ca3d537cfdfbd6b4a2820b7c864b1ee0 /examples/llama-bench/llama-bench.cpp
parentd8fd0ccf6ac8b07791ffd1575eed436930854ae3 (diff)
llama : add pipeline parallelism support (#6017)
* llama : add pipeline parallelism support for batch processing with multiple CUDA GPUs ggml-ci * server : add -ub, --ubatch-size parameter * fix server embedding test * llama : fix Mamba inference for pipeline parallelism Tested to work correctly with both `main` and `parallel` examples. * llama : limit max batch size to n_batch * add LLAMA_SCHED_MAX_COPIES to configure the number of input copies for pipeline parallelism default increase to 4 (from 2) changing this value may improve performance for some systems, but increases memory usage * fix hip build * fix sycl build (disable cpy_tensor_async) * fix hip build * llama : limit n_batch and n_ubatch to n_ctx during context creation * llama : fix norm backend * batched-bench : sync after decode * swiftui : sync after decode * ggml : allow ggml_get_rows to use multiple threads if they are available * check n_ubatch >= n_tokens with non-casual attention * llama : do not limit n_batch to n_ctx with non-casual attn * server : construct batch with size of llama_n_batch * ggml_backend_cpu_graph_compute : fix return value when alloc fails * llama : better n_batch and n_ubatch comment * fix merge * small fix * reduce default n_batch to 2048 --------- Co-authored-by: Francis Couture-Harpin <git@compilade.net> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/llama-bench/llama-bench.cpp')
-rw-r--r--examples/llama-bench/llama-bench.cpp53
1 files changed, 43 insertions, 10 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 2ff86ef6..bf94e7e7 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -164,6 +164,7 @@ struct cmd_params {
std::vector<int> n_prompt;
std::vector<int> n_gen;
std::vector<int> n_batch;
+ std::vector<int> n_ubatch;
std::vector<ggml_type> type_k;
std::vector<ggml_type> type_v;
std::vector<int> n_threads;
@@ -183,7 +184,8 @@ static const cmd_params cmd_params_defaults = {
/* model */ {"models/7B/ggml-model-q4_0.gguf"},
/* n_prompt */ {512},
/* n_gen */ {128},
- /* n_batch */ {512},
+ /* n_batch */ {2048},
+ /* n_ubatch */ {512},
/* type_k */ {GGML_TYPE_F16},
/* type_v */ {GGML_TYPE_F16},
/* n_threads */ {get_num_physical_cores()},
@@ -208,6 +210,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -p, --n-prompt <n> (default: %s)\n", join(cmd_params_defaults.n_prompt, ",").c_str());
printf(" -n, --n-gen <n> (default: %s)\n", join(cmd_params_defaults.n_gen, ",").c_str());
printf(" -b, --batch-size <n> (default: %s)\n", join(cmd_params_defaults.n_batch, ",").c_str());
+ printf(" -ub N, --ubatch-size <n> (default: %s)\n", join(cmd_params_defaults.n_ubatch, ",").c_str());
printf(" -ctk <t>, --cache-type-k <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_k, ggml_type_name), ",").c_str());
printf(" -ctv <t>, --cache-type-v <t> (default: %s)\n", join(transform_to_str(cmd_params_defaults.type_v, ggml_type_name), ",").c_str());
printf(" -t, --threads <n> (default: %s)\n", join(cmd_params_defaults.n_threads, ",").c_str());
@@ -217,7 +220,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
- printf(" -ts, --tensor_split <ts0/ts1/..> (default: 0)\n");
+ printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
printf(" -r, --repetitions <n> (default: %d)\n", cmd_params_defaults.reps);
printf(" -o, --output <csv|json|md|sql> (default: %s)\n", output_format_str(cmd_params_defaults.output_format));
printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
@@ -297,6 +300,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = split<int>(argv[i], split_delim);
params.n_batch.insert(params.n_batch.end(), p.begin(), p.end());
+ } else if (arg == "-ub" || arg == "--ubatch-size") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = split<int>(argv[i], split_delim);
+ params.n_ubatch.insert(params.n_ubatch.end(), p.begin(), p.end());
} else if (arg == "-ctk" || arg == "--cache-type-k") {
if (++i >= argc) {
invalid_param = true;
@@ -455,6 +465,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.n_prompt.empty()) { params.n_prompt = cmd_params_defaults.n_prompt; }
if (params.n_gen.empty()) { params.n_gen = cmd_params_defaults.n_gen; }
if (params.n_batch.empty()) { params.n_batch = cmd_params_defaults.n_batch; }
+ if (params.n_ubatch.empty()) { params.n_ubatch = cmd_params_defaults.n_ubatch; }
if (params.type_k.empty()) { params.type_k = cmd_params_defaults.type_k; }
if (params.type_v.empty()) { params.type_v = cmd_params_defaults.type_v; }
if (params.n_gpu_layers.empty()) { params.n_gpu_layers = cmd_params_defaults.n_gpu_layers; }
@@ -474,6 +485,7 @@ struct cmd_params_instance {
int n_prompt;
int n_gen;
int n_batch;
+ int n_ubatch;
ggml_type type_k;
ggml_type type_v;
int n_threads;
@@ -511,6 +523,7 @@ struct cmd_params_instance {
cparams.n_ctx = n_prompt + n_gen;
cparams.n_batch = n_batch;
+ cparams.n_ubatch = n_ubatch;
cparams.type_k = type_k;
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
@@ -532,6 +545,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & mmp : params.use_mmap)
for (const auto & embd : params.embeddings)
for (const auto & nb : params.n_batch)
+ for (const auto & nub : params.n_ubatch)
for (const auto & tk : params.type_k)
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
@@ -545,6 +559,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ n_prompt,
/* .n_gen = */ 0,
/* .n_batch = */ nb,
+ /* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
@@ -568,6 +583,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .n_prompt = */ 0,
/* .n_gen = */ n_gen,
/* .n_batch = */ nb,
+ /* .n_ubatch = */ nub,
/* .type_k = */ tk,
/* .type_v = */ tv,
/* .n_threads = */ nt,
@@ -604,6 +620,7 @@ struct test {
uint64_t model_size;
uint64_t model_n_params;
int n_batch;
+ int n_ubatch;
int n_threads;
ggml_type type_k;
ggml_type type_v;
@@ -627,6 +644,7 @@ struct test {
model_size = llama_model_size(lmodel);
model_n_params = llama_model_n_params(lmodel);
n_batch = inst.n_batch;
+ n_ubatch = inst.n_ubatch;
n_threads = inst.n_threads;
type_k = inst.type_k;
type_v = inst.type_v;
@@ -705,7 +723,8 @@ struct test {
"cuda", "opencl", "vulkan", "kompute", "metal", "sycl", "gpu_blas", "blas",
"cpu_info", "gpu_info",
"model_filename", "model_type", "model_size", "model_n_params",
- "n_batch", "n_threads", "type_k", "type_v",
+ "n_batch", "n_ubatch",
+ "n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
"main_gpu", "no_kv_offload",
"tensor_split", "use_mmap", "embeddings",
@@ -719,7 +738,8 @@ struct test {
enum field_type {STRING, BOOL, INT, FLOAT};
static field_type get_field_type(const std::string & field) {
- if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
+ if (field == "build_number" || field == "n_batch" || field == "n_ubatch" ||
+ field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" ||
@@ -759,7 +779,8 @@ struct test {
std::to_string(metal), std::to_string(sycl), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info,
model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
- std::to_string(n_batch), std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
+ std::to_string(n_batch), std::to_string(n_ubatch),
+ std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
std::to_string(main_gpu), std::to_string(no_kv_offload),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
@@ -957,6 +978,9 @@ struct markdown_printer : public printer {
if (params.n_batch.size() > 1 || params.n_batch != cmd_params_defaults.n_batch) {
fields.emplace_back("n_batch");
}
+ if (params.n_ubatch.size() > 1 || params.n_ubatch != cmd_params_defaults.n_ubatch) {
+ fields.emplace_back("n_ubatch");
+ }
if (params.type_k.size() > 1 || params.type_k != cmd_params_defaults.type_k) {
fields.emplace_back("type_k");
}
@@ -1096,25 +1120,32 @@ struct sql_printer : public printer {
};
static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
+ llama_set_n_threads(ctx, n_threads, n_threads);
+
+ //std::vector<llama_token> tokens(n_prompt, llama_token_bos(llama_get_model(ctx)));
+ //llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt, n_past, 0));
+ //GGML_UNUSED(n_batch);
+
std::vector<llama_token> tokens(n_batch, llama_token_bos(llama_get_model(ctx)));
int n_processed = 0;
- llama_set_n_threads(ctx, n_threads, n_threads);
-
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
n_processed += n_tokens;
}
+
+ llama_synchronize(ctx);
}
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
- llama_token token = llama_token_bos(llama_get_model(ctx));
-
llama_set_n_threads(ctx, n_threads, n_threads);
+ llama_token token = llama_token_bos(llama_get_model(ctx));
+
for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
+ llama_synchronize(ctx);
}
}
@@ -1203,7 +1234,8 @@ int main(int argc, char ** argv) {
// warmup run
if (t.n_prompt > 0) {
- test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
+ //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
+ test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
test_gen(ctx, 1, 0, t.n_threads);
@@ -1219,6 +1251,7 @@ int main(int argc, char ** argv) {
if (t.n_gen > 0) {
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
}
+
uint64_t t_ns = get_time_ns() - t_start;
t.samples_ns.push_back(t_ns);
}