author     Kawrakow <iwankawrakow@gmail.com>        2025-03-02 13:47:38 +0200
committer  GitHub <noreply@github.com>              2025-03-02 13:47:38 +0200
commit     a89adaa78f505675be7be6180f419b4b0158c15a (patch)
tree       ad82fa3ad44f66f37885bdf0d0d025166eff9535 /examples
parent     ef9a3d17b52bb5f6d55f7ef7e05e41e22f2ad81d (diff)
SER - Smart Expert Reduction (#239)
* A better way to measure the cost of ggml_barrier
* Smart expert selection
* Add ser option to llama-bench
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
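
The diff below adds only the benchmark-side plumbing for SER: the <i,f> pair passed via -ser lands in cparams.min_experts and cparams.thresh_experts. The selection logic itself lives outside examples/ and is not part of this diff, so what follows is only a rough sketch of what a (min_experts, threshold) pair could mean for MoE routing; the function name and the exact rule are illustrative assumptions, not this PR's implementation.

#include <cstdio>
#include <vector>

// Hypothetical sketch of (min_experts, thresh_experts) semantics; the real
// SER logic is in the core library, not in this examples/ diff.
// `weights`: routing weights of the top-k experts, sorted descending.
// Keep an expert while its weight is at least `thresh` times the best one,
// but never keep fewer than `min_experts` experts.
static int n_experts_to_keep(const std::vector<float> & weights, int min_experts, float thresh) {
    if (min_experts < 0 || weights.empty()) {
        return (int) weights.size(); // -1 is the default below: treat SER as disabled (assumption)
    }
    int n = 0;
    for (float w : weights) {
        if (n >= min_experts && w < thresh * weights.front()) break;
        ++n;
    }
    return n;
}

int main() {
    std::vector<float> w = {0.40f, 0.30f, 0.15f, 0.10f, 0.05f};
    // e.g. `-ser 2,0.5`: keep at least 2 experts, drop the rest below half the top weight
    printf("keep %d of %zu experts\n", n_experts_to_keep(w, 2, 0.5f), w.size());
    return 0;
}

Under that reading, the default {-1, 0.0f} added below keeps every routed expert, i.e. SER is off unless explicitly requested.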
Diffstat (limited to 'examples')
-rw-r--r--   examples/llama-bench/llama-bench.cpp   65
1 file changed, 63 insertions(+), 2 deletions(-)
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index a08cb762..167525bc 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -215,6 +215,9 @@ static std::string pair_str(const std::pair<int, int> & p) {
     return buf;
 }
 
+// Ser = Smart Expert Reduction
+using Ser = std::pair<int,float>;
+
 struct cmd_params {
     std::vector<std::string> model;
     std::vector<int> n_prompt;
@@ -234,6 +237,7 @@ struct cmd_params {
     std::vector<bool> flash_attn;
     std::vector<int> mla_attn;
     std::vector<int> attn_max_batch;
+    std::vector<Ser> ser;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;
@@ -267,6 +271,7 @@ static const cmd_params cmd_params_defaults = {
     /* flash_attn     */ {false},
     /* mla_attn       */ {0},
     /* attn_max_batch */ {0},
+    /* ser            */ {{-1,0.0f}},
     /* tensor_split   */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap       */ {true},
     /* embeddings     */ {false},
@@ -304,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -fa, --flash-attn <0|1>              (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf("  -mla, --mla-attn <0|1|2>             (default: %s)\n", join(cmd_params_defaults.mla_attn, ",").c_str());
     printf("  -amb, --attn-max-batch <i>           (default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
+    printf("  -ser, --smart-expert-reduction <i,f>(default: %s)\n", join(cmd_params_defaults.attn_max_batch, ",").c_str());
     printf("  -mmp, --mmap <0|1>                   (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf("  --numa <distribute|isolate|numactl>  (default: disabled)\n");
     printf("  -embd, --embeddings <0|1>            (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
@@ -387,6 +393,28 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
     }
     return true;
 }
+template<class T1, class T2>
+std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
+    std::vector<std::pair<T1,T2>> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    T1 first_value;
+    int i = 0;
+    while (std::getline(str_stream, token, delim)) {
+        std::istringstream token_stream(token);
+        if (i%2 == 0) {
+            token_stream >> first_value;
+            if (token_stream.fail()) return {};
+        } else {
+            T2 value;
+            token_stream >> value;
+            if (token_stream.fail()) return {};
+            values.emplace_back(first_value, value);
+        }
+        i++;
+    }
+    return values;
+}
 }
 
 static cmd_params parse_cmd_params(int argc, char ** argv) {
@@ -588,6 +616,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = string_split<int>(argv[i], split_delim);
             params.attn_max_batch.insert(params.attn_max_batch.end(), p.begin(), p.end());
+        } else if (arg == "-ser" || arg == "--smart-expert-reduction") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = string_split_pairs<int,float>(argv[i], split_delim);
+            params.ser.insert(params.ser.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -701,6 +736,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.flash_attn.empty())     { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.mla_attn.empty())       { params.mla_attn = cmd_params_defaults.mla_attn; }
     if (params.attn_max_batch.empty()) { params.attn_max_batch = cmd_params_defaults.attn_max_batch; }
+    if (params.ser.empty())            { params.ser = cmd_params_defaults.ser; }
     if (params.tensor_split.empty())   { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty())       { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty())     { params.embeddings = cmd_params_defaults.embeddings; }
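
A note on the string_split_pairs helper added above: it reads the comma-separated -ser argument as alternating int/float tokens, so a single flag can carry several SER configurations to benchmark. A self-contained check of that behavior, with the helper copied verbatim from the hunk and a made-up argument:

#include <cstdio>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Copy of string_split_pairs from the hunk above: tokens at even positions
// are read as T1, each odd-position token completes a (T1,T2) pair.
template<class T1, class T2>
std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
    std::vector<std::pair<T1,T2>> values;
    std::istringstream str_stream(str);
    std::string token;
    T1 first_value;
    int i = 0;
    while (std::getline(str_stream, token, delim)) {
        std::istringstream token_stream(token);
        if (i%2 == 0) {
            token_stream >> first_value;
            if (token_stream.fail()) return {};
        } else {
            T2 value;
            token_stream >> value;
            if (token_stream.fail()) return {};
            values.emplace_back(first_value, value);
        }
        i++;
    }
    return values;
}

int main() {
    // hypothetical invocation: llama-bench ... -ser 4,0.5,6,0.25
    for (auto & [min_experts, thresh] : string_split_pairs<int,float>("4,0.5,6,0.25", ',')) {
        printf("min_experts=%d thresh=%g\n", min_experts, thresh);
    }
    return 0;
}

An odd trailing token (e.g. 4,0.5,6) is dropped silently, and any malformed token makes the parse return an empty vector, in which case the -ser branch above simply adds no configurations.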
@@ -739,6 +775,7 @@ struct cmd_params_instance {
     bool flash_attn;
     int mla_attn;
     int attn_max_batch;
+    Ser ser;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -787,6 +824,8 @@ struct cmd_params_instance {
         cparams.mla_attn = mla_attn;
         cparams.attn_max_batch = attn_max_batch;
         cparams.fused_moe_up_gate = fmoe;
+        cparams.min_experts = ser.first;
+        cparams.thresh_experts = ser.second;
         cparams.embeddings = embeddings;
 
         return cparams;
@@ -813,6 +852,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & fa : params.flash_attn)
     for (const auto & mla : params.mla_attn)
     for (const auto & amb : params.attn_max_batch)
+    for (const auto & ser : params.ser)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {
@@ -836,6 +876,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .flash_attn   = */ fa,
                 /* .mla_attn     = */ mla,
                 /* .attn_max_b   = */ amb,
+                /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -868,6 +909,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .flash_attn   = */ fa,
                 /* .mla_attn     = */ mla,
                 /* .attn_max_b   = */ amb,
+                /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -900,6 +942,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .flash_attn   = */ fa,
                 /* .mla_attn     = */ mla,
                 /* .attn_max_b   = */ amb,
+                /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -932,6 +975,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .flash_attn   = */ fa,
                 /* .mla_attn     = */ mla,
                 /* .attn_max_b   = */ amb,
+                /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
@@ -975,6 +1019,7 @@ struct test {
     bool flash_attn;
     int mla_attn;
     int attn_max_batch;
+    Ser ser;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;
@@ -1007,6 +1052,7 @@ struct test {
        flash_attn = inst.flash_attn;
        mla_attn = inst.mla_attn;
        attn_max_batch = inst.attn_max_batch;
+       ser = inst.ser;
        tensor_split = inst.tensor_split;
        use_mmap = inst.use_mmap;
        embeddings = inst.embeddings;
@@ -1101,7 +1147,7 @@ struct test {
             "n_batch", "n_ubatch",
             "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch",
+            "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", "attn_max_batch", "ser",
             "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",
@@ -1149,6 +1195,11 @@ struct test {
                 tensor_split_str += "/";
             }
         }
+        auto ser_to_string = [] (const Ser& ser) {
+            std::ostringstream str;
+            str << ser.first << ',' << ser.second;
+            return str.str();
+        };
         std::vector<std::string> values = {
             build_commit, std::to_string(build_number),
             std::to_string(cuda), std::to_string(vulkan), std::to_string(vulkan),
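
The ser_to_string lambda just added is what feeds the new "ser" column of the test output: it joins the pair with a comma using plain ostream formatting. A standalone equivalent, showing what the column will contain:

#include <cstdio>
#include <sstream>
#include <string>
#include <utility>

using Ser = std::pair<int,float>;

// Same formatting as the ser_to_string lambda in the hunk above.
static std::string ser_to_string(const Ser & ser) {
    std::ostringstream str;
    str << ser.first << ',' << ser.second;
    return str.str();
}

int main() {
    printf("%s\n", ser_to_string({-1, 0.0f}).c_str()); // default -> "-1,0"
    printf("%s\n", ser_to_string({ 4, 0.25f}).c_str()); // -> "4,0.25"
    return 0;
}

Note the default {-1, 0.0f} renders as -1,0, the same shape that string_split_pairs accepts on input.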
@@ -1158,7 +1209,8 @@
             std::to_string(n_batch), std::to_string(n_ubatch),
             std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
             std::to_string(n_gpu_layers), split_mode_str(split_mode),
-            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), std::to_string(attn_max_batch),
+            std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
+            std::to_string(mla_attn), std::to_string(attn_max_batch), ser_to_string(ser),
             tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
             std::to_string(repack), std::to_string(fmoe),
             std::to_string(n_prompt), std::to_string(n_gen), test_time,
             std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -1328,6 +1380,9 @@ struct markdown_printer : public printer {
         if (field == "attn_max_batch") {
             return 5;
         }
+        if (field == "ser") {
+            return 10;
+        }
         if (field == "use_mmap") {
             return 4;
         }
@@ -1371,6 +1426,9 @@
         if (field == "attn_max_batch") {
             return "amb";
         }
+        if (field == "ser") {
+            return "ser";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }
@@ -1432,6 +1490,9 @@
         if (params.attn_max_batch.size() > 1 || params.attn_max_batch != cmd_params_defaults.mla_attn) {
             fields.emplace_back("attn_max_batch");
         }
+        if (params.ser.size() > 1 || params.ser != cmd_params_defaults.ser) {
+            fields.emplace_back("ser");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }