diff options
author | Kawrakow <iwankawrakow@gmail.com> | 2025-02-23 14:31:11 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-02-23 14:31:11 +0200 |
commit | ac1d259b93eccfa7371c6b00c5749400ff2b2aea (patch) | |
tree | fe8bb34c9dcbea805595c5087f00b188bb89fc05 /examples | |
parent | 46bf73a37f1aabe6f0b40365b0c7b2ba831905f5 (diff) |
Fused MoE ffn_up and ffn_gate (#229)
* Fusing MoE up * unary(gate)
* Fusing MoE up * unary(gate): CUDA
We get ~13% speedup for PP-512 and ~2% for TG-128
for DeepSeek-Lite
* On CUDA also fuse MoE down * (up * unary(gate))
in case the MUL_MAT_ID op for the down experts is the next
op in the graph.
* Command line option to enable fused MoE up*unary(gate)
* Add fmoe option to llama-bench
* Adding forgotten gelu, relu, silu on ARM
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r-- | examples/llama-bench/llama-bench.cpp | 33 |
1 files changed, 30 insertions, 3 deletions
diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 0222c213..b0790e20 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -241,6 +241,7 @@ struct cmd_params { bool verbose; bool warmup; bool repack = false; + bool fmoe = false; output_formats output_format; output_formats output_format_stderr; }; @@ -271,6 +272,7 @@ static const cmd_params cmd_params_defaults = { /* verbose */ false, /* warmup */ true, /* repack */ false, + /* fmoe */ false, /* output_format */ MARKDOWN, /* output_format_stderr */ NONE, }; @@ -307,6 +309,7 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -v, --verbose (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0"); printf(" -w, --warmup <0|1> (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0"); printf(" -rtr, --run-time-repack <0|1> (default: %s)\n", cmd_params_defaults.repack ? "1" : "0"); + printf(" -fmoe, --fused-moe <0|1> (default: %s)\n", cmd_params_defaults.fmoe? "1" : "0"); printf("\n"); printf("Multiple values can be given for each parameter by separating them with ',' or by specifying the parameter multiple times.\n"); } @@ -607,6 +610,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { break; } params.repack = std::stoi(argv[i]); + } else if (arg == "-fmoe" || arg == "--fused-moe") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.fmoe = std::stoi(argv[i]); } else { invalid_param = true; break; @@ -675,6 +684,7 @@ struct cmd_params_instance { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -714,6 +724,7 @@ struct cmd_params_instance { cparams.offload_kqv = !no_kv_offload; cparams.flash_attn = flash_attn; cparams.mla_attn = mla_attn; + cparams.fused_moe_up_gate = fmoe; cparams.embeddings = embeddings; return cparams; @@ -765,6 +776,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -794,6 +806,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -823,6 +836,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -852,6 +866,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param /* .use_mmap = */ mmp, /* .embeddings = */ embd, /* .repack = */ params.repack, + /* .fmoe = */ params.fmoe, }; instances.push_back(instance); } @@ -892,6 +907,7 @@ struct test { bool use_mmap; bool embeddings; bool repack = false; + bool fmoe = false; int n_prompt; int n_gen; std::string test_time; @@ -922,6 +938,7 @@ struct test { use_mmap = inst.use_mmap; embeddings = inst.embeddings; repack = inst.repack; + fmoe = inst.fmoe; n_prompt = inst.n_prompt; n_gen = inst.n_gen; test_kind = inst.test_kind; @@ -1012,7 +1029,7 @@ struct test { "n_threads", "type_k", "type_v", "n_gpu_layers", "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "mla_attn", - "tensor_split", "use_mmap", "embeddings", "repack", + "tensor_split", "use_mmap", "embeddings", "repack", "fused_moe", "n_prompt", "n_gen", "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", "test", @@ -1033,7 +1050,8 @@ struct test { } if (field == "cuda" || field == "vulkan" || field == "kompute" || field == "metal" || field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" || - field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack") { + field == "flash_attn" || field == "mla_attn" || field == "use_mmap" || field == "embeddings" || field == "repack" || + field == "fused_moe") { return BOOL; } if (field == "avg_ts" || field == "stddev_ts") { @@ -1068,7 +1086,7 @@ struct test { std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v), std::to_string(n_gpu_layers), split_mode_str(split_mode), std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn), std::to_string(mla_attn), - tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), + tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings), std::to_string(repack), std::to_string(fmoe), std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(avg_ns()), std::to_string(stdev_ns()), std::to_string(avg_ts()), std::to_string(stdev_ts()), @@ -1240,6 +1258,9 @@ struct markdown_printer : public printer { if (field == "repack") { return 3; } + if (field == "fused_moe") { + return 4; + } if (field == "test") { return 13; } @@ -1277,6 +1298,9 @@ struct markdown_printer : public printer { if (field == "repack") { return "rtr"; } + if (field == "fused_moe") { + return "fmoe"; + } if (field == "embeddings") { return "embd"; } @@ -1338,6 +1362,9 @@ struct markdown_printer : public printer { if (params.repack != cmd_params_defaults.repack) { fields.emplace_back("repack"); } + if (params.fmoe != cmd_params_defaults.fmoe) { + fields.emplace_back("fused_moe"); + } fields.emplace_back("test"); fields.emplace_back("t/s"); |