author    Kawrakow <iwankawrakow@gmail.com>    2025-02-09 19:48:44 +0200
committer GitHub <noreply@github.com>          2025-02-09 19:48:44 +0200
commit    c12f73ba6153d162f36434cb48e36dd3649b7701 (patch)
tree      3594697f680959c7130a3f109cd9e4abbc8d6e7d /examples/llama-bench/llama-bench.cpp
parent    cae2b81155fdad75b7beab3a835c438120412969 (diff)
Add optional MLA (#188)
* Deepseek MLA Optimizations

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* Make MLA optional

* Remove some unnecessary copies in the MLA attention

* Deepseek MLA Optimizations V2 (#195)

* Avoid allocating MHA KV cache when MLA is turned on

* Added missing gguf-py file

* Added final optimizations

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>

* Make sure we do have wk_b and wv_b before enabling MLA

---------

Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>

* Use type_k and type_v to set the types of the MLA caches

They were hard-coded at f16. On my Ryzen-7950X with native bf16 support I get a
fairly significant PP performance boost with bf16 KV-cache: PP-4096 = 320 t/s,
up from 292 t/s with fp16 KV-cache.

* Better gemm strategy when nth > nhead

It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads (with or
without MLA). Before this commit, when nth > nhead, heads were processed
sequentially with all nth threads participating in each matrix multiplication.
Now we find the gcd of nhead and nth and split threads into groups, with each
group processing nhead/gcd heads.

---------

Co-authored-by: Saood Karim <saood05@gmail.com>
Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
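The gcd-based grouping described in the last bullet can be sketched in a few lines. The snippet below is only an illustration of one consistent reading of that description (g = gcd(nhead, nth) groups of nth/g threads, each group covering nhead/g heads); the function and variable names are invented for this example, and the real scheduling lives in the ggml matrix-multiplication code, not in llama-bench.

    // Illustrative sketch only: one consistent reading of the gcd-based
    // thread grouping from the commit message; not the actual implementation.
    #include <cstdio>
    #include <numeric>   // std::gcd (C++17)

    static void assign_heads(int nhead, int nth) {
        const int g         = std::gcd(nhead, nth);
        const int grp_size  = nth   / g;   // threads per group
        const int grp_heads = nhead / g;   // heads each group is responsible for
        for (int ith = 0; ith < nth; ++ith) {
            const int group = ith / grp_size;
            printf("thread %2d -> group %d, heads [%d, %d)\n",
                   ith, group, group * grp_heads, (group + 1) * grp_heads);
        }
    }

    int main() {
        // Assuming 16 attention heads for DeepSeek-Lite and 32 threads:
        // gcd = 16, i.e. 16 groups of 2 threads, each group owning one head.
        assign_heads(/*nhead=*/16, /*nth=*/32);
        return 0;
    }

Under these assumed numbers the groups work on different heads concurrently, instead of all 32 threads walking the 16 heads one after another.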
Diffstat (limited to 'examples/llama-bench/llama-bench.cpp')
-rw-r--r--  examples/llama-bench/llama-bench.cpp  35
1 file changed, 32 insertions(+), 3 deletions(-)
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 42320da8..41b93df5 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -232,6 +232,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
+ std::vector<bool> mla_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -261,6 +262,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
+ /* mla_attn */ {false},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -294,6 +296,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
+ printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -526,6 +529,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
}
auto p = string_split<bool>(argv[i], split_delim);
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
+ } else if (arg == "-mla" || arg == "--mla-attn") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ auto p = string_split<bool>(argv[i], split_delim);
+ params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
invalid_param = true;
@@ -621,6 +631,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
+ if (params.mla_attn.empty()) { params.mla_attn = cmd_params_defaults.mla_attn; }
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -656,6 +667,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
+ bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -698,6 +710,7 @@ struct cmd_params_instance {
cparams.type_v = type_v;
cparams.offload_kqv = !no_kv_offload;
cparams.flash_attn = flash_attn;
+ cparams.mla_attn = mla_attn;
cparams.embeddings = embeddings;
return cparams;
@@ -722,6 +735,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
for (const auto & tv : params.type_v)
for (const auto & nkvo : params.no_kv_offload)
for (const auto & fa : params.flash_attn)
+ for (const auto & mla : params.mla_attn)
for (const auto & nt : params.n_threads) {
for (const auto & n_prompt : params.n_prompt) {
if (n_prompt == 0) {
@@ -743,6 +757,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -771,6 +786,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -799,6 +815,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
@@ -827,6 +844,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
/* .main_gpu = */ mg,
/* .no_kv_offload= */ nkvo,
/* .flash_attn = */ fa,
+ /* .mla_attn = */ mla,
/* .tensor_split = */ ts,
/* .use_mmap = */ mmp,
/* .embeddings = */ embd,
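Because the new mla loop nests inside the existing parameter loops, every value given to -mla is crossed with every value of the other swept parameters before the instance-construction sites above are reached; for example, -mla 0,1 together with -fa 0,1 yields four cmd_params_instance entries, and thus four result rows, per model and test type.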
@@ -866,6 +884,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
+ bool mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -895,6 +914,7 @@ struct test {
main_gpu = inst.main_gpu;
no_kv_offload = inst.no_kv_offload;
flash_attn = inst.flash_attn;
+ mla_attn = inst.mla_attn;
tensor_split = inst.tensor_split;
use_mmap = inst.use_mmap;
embeddings = inst.embeddings;
@@ -988,7 +1008,7 @@ struct test {
"n_batch", "n_ubatch",
"n_threads", "type_k", "type_v",
"n_gpu_layers", "split_mode",
- "main_gpu", "no_kv_offload", "flash_attn",
+ "main_gpu", "no_kv_offload", "flash_attn", "mla_attn",
"tensor_split", "use_mmap", "embeddings", "repack",
"n_prompt", "n_gen", "test_time",
"avg_ns", "stddev_ns",
@@ -1010,7 +1030,7 @@ struct test {
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
+ field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") {
return BOOL;
}
if (field == "avg_ts" || field == "stddev_ts") {
@@ -1044,7 +1064,7 @@ struct test {
std::to_string(n_batch), std::to_string(n_ubatch),
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
std::to_string(n_gpu_layers), split_mode_str(split_mode),
- std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+ std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn),
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack),
std::to_string(n_prompt), std::to_string(n_gen), test_time,
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1208,6 +1228,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return 2;
}
+ if (field == "mla_attn") {
+ return 3;
+ }
if (field == "use_mmap") {
return 4;
}
@@ -1242,6 +1265,9 @@ struct markdown_printer : public printer {
if (field == "flash_attn") {
return "fa";
}
+ if (field == "mla_attn") {
+ return "mla";
+ }
if (field == "use_mmap") {
return "mmap";
}
@@ -1294,6 +1320,9 @@ struct markdown_printer : public printer {
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
fields.emplace_back("flash_attn");
}
+ if (params.mla_attn.size() > 1 || params.mla_attn != cmd_params_defaults.mla_attn) {
+ fields.emplace_back("mla_attn");
+ }
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
fields.emplace_back("tensor_split");
}
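Taken together, the printer changes above mean that whenever -mla is swept (or set to a non-default value) the markdown report gains a three-character "mla" column alongside the existing two-character "fa" column. Purely as an illustration of the header shape (other columns omitted, not captured output), a run sweeping flash attention, MLA, and mmap might show:

    | fa | mla | mmap | ... |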