Diffstat (limited to 'examples')
 examples/llama-bench/llama-bench.cpp | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 438d2a7c..5756843a 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -232,7 +232,7 @@ struct cmd_params {
std::vector<int> main_gpu;
std::vector<bool> no_kv_offload;
std::vector<bool> flash_attn;
- std::vector<bool> mla_attn;
+ std::vector<int> mla_attn;
std::vector<std::vector<float>> tensor_split;
std::vector<bool> use_mmap;
std::vector<bool> embeddings;
@@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = {
/* main_gpu */ {0},
/* no_kv_offload */ {false},
/* flash_attn */ {false},
- /* mla_attn */ {false},
+ /* mla_attn */ {0},
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
/* use_mmap */ {true},
/* embeddings */ {false},
@@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
- printf(" -mla, --mla-attn <0|1> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
+ printf(" -mla, --mla-attn <0|1|2> (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
invalid_param = true;
break;
}
- auto p = string_split<bool>(argv[i], split_delim);
+ auto p = string_split<int>(argv[i], split_delim);
params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
} else if (arg == "-mmp" || arg == "--mmap") {
if (++i >= argc) {
@@ -726,7 +726,7 @@ struct cmd_params_instance {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
- bool mla_attn;
+ int mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -955,7 +955,7 @@ struct test {
int main_gpu;
bool no_kv_offload;
bool flash_attn;
- bool mla_attn;
+ int mla_attn;
std::vector<float> tensor_split;
bool use_mmap;
bool embeddings;
@@ -1097,13 +1097,13 @@ struct test {
field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" ||
- field == "n_prompt" || field == "n_gen" ||
+ field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
field == "avg_ns" || field == "stddev_ns") {
return INT;
}
if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
- field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
+ field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
field == "fused_moe") {
return BOOL;
}
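
With this change, -mla takes an integer mode instead of a 0/1 toggle, and since llama-bench expands each comma-separated parameter list into separate test instances, several MLA modes can now be compared in a single run. A hypothetical invocation (the model path is a placeholder):

./llama-bench -m model.gguf -mla 0,1,2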
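
The parsing hunk above leans on the templated string_split helper to turn the argument into a std::vector<int>. A minimal sketch of such a helper, assuming a std::istringstream-based implementation (the helper that ships with llama-bench may differ in detail):

#include <sstream>
#include <string>
#include <vector>

// Split 'str' on 'delim' and convert each token to T via operator>>.
// Sketch only; llama-bench provides its own version of this helper.
template <typename T>
static std::vector<T> string_split(const std::string & str, char delim) {
    std::vector<T> values;
    std::istringstream str_stream(str);
    std::string token;
    while (std::getline(str_stream, token, delim)) {
        T value{};
        std::istringstream token_stream(token);
        token_stream >> value;
        values.push_back(value);
    }
    return values;
}

Instantiated as string_split<int>(argv[i], split_delim), "0,1,2" yields {0, 1, 2}. The old string_split<bool> instantiation could only extract "0" or "1" via operator>>, which is why the element type changes together with the <0|1|2> usage string.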