Diffstat (limited to 'examples')
-rw-r--r--    examples/llama-bench/llama-bench.cpp    16
1 file changed, 8 insertions, 8 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index 438d2a7c..5756843a 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -232,7 +232,7 @@ struct cmd_params {
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
     std::vector<bool> flash_attn;
-    std::vector<bool> mla_attn;
+    std::vector<int> mla_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -264,7 +264,7 @@ static const cmd_params cmd_params_defaults = {
     /* main_gpu      */ {0},
     /* no_kv_offload */ {false},
     /* flash_attn    */ {false},
-    /* mla_attn      */ {false},
+    /* mla_attn      */ {0},
     /* tensor_split  */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap      */ {true},
     /* embeddings    */ {false},
@@ -300,7 +300,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -mg, --main-gpu <i>                 (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf("  -nkvo, --no-kv-offload <0|1>        (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
     printf("  -fa, --flash-attn <0|1>             (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
-    printf("  -mla, --mla-attn <0|1>              (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
+    printf("  -mla, --mla-attn <0|1|2>            (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
     printf("  -mmp, --mmap <0|1>                  (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  --numa <distribute|isolate|numactl> (default: disabled)\n");
     printf("  -embd, --embeddings <0|1>           (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -576,7 +576,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 invalid_param = true;
                 break;
             }
-            auto p = string_split<bool>(argv[i], split_delim);
+            auto p = string_split<int>(argv[i], split_delim);
             params.mla_attn.insert(params.mla_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
@@ -726,7 +726,7 @@ struct cmd_params_instance {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
-    bool mla_attn;
+    int mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -955,7 +955,7 @@ struct test {
     int main_gpu;
     bool no_kv_offload;
     bool flash_attn;
-    bool mla_attn;
+    int mla_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -1097,13 +1097,13 @@ struct test {
            field == "n_threads" || field == "model_size" || field == "model_n_params" ||
            field == "n_gpu_layers" || field == "main_gpu" ||
-           field == "n_prompt" || field == "n_gen" ||
+           field == "n_prompt" || field == "n_gen" || field == "mla_attn" ||
            field == "avg_ns" || field == "stddev_ns") {
            return INT;
        }
        if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" ||
            field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-           field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
+           field == "flash_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" ||
            field == "fused_moe") {
            return BOOL;
        }
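The functional effect of the patch: -mla is no longer a boolean toggle. Because string_split<int> now parses each comma-separated value as an integer, an invocation such as llama-bench -mla 0,1,2 expands into one benchmark run per MLA mode, the same way the other vector-valued parameters form the test matrix. As a rough illustration of that parsing step only (a minimal standalone sketch, not the project's actual string_split helper), the following splits a delimited argument into one int per benchmark variant:

// Standalone sketch of comma-separated int parsing, analogous to
// string_split<int>(argv[i], split_delim) in the diff above.
// split_ints is a hypothetical helper written for this example.
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

static std::vector<int> split_ints(const std::string & arg, char delim) {
    std::vector<int> values;
    std::stringstream ss(arg);
    std::string item;
    while (std::getline(ss, item, delim)) {
        values.push_back(std::stoi(item)); // each token becomes one benchmark variant
    }
    return values;
}

int main() {
    // with the int-typed vector, each listed mode yields its own test run
    for (int v : split_ints("0,1,2", ',')) {
        printf("mla_attn = %d\n", v);
    }
    return 0;
}

Widening the type rather than adding a separate flag keeps the existing pattern intact: every cmd_params field is a vector of values to sweep, and the last hunk moves mla_attn from the BOOL field list to the INT list so the result tables report the numeric mode.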