summaryrefslogtreecommitdiff
path: root/src/llama.cpp
diff options
context:
space:
mode:
authorKawrakow <iwankawrakow@gmail.com>2025-02-23 14:31:11 +0200
committerGitHub <noreply@github.com>2025-02-23 14:31:11 +0200
commitac1d259b93eccfa7371c6b00c5749400ff2b2aea (patch)
treefe8bb34c9dcbea805595c5087f00b188bb89fc05 /src/llama.cpp
parent46bf73a37f1aabe6f0b40365b0c7b2ba831905f5 (diff)
Fused MoE ffn_up and ffn_gate (#229)
* Fusing MoE up * unary(gate) * Fusing MoE up * unary(gate): CUDA We get ~13% speedup for PP-512 and ~2% for TG-128 for DeepSeek-Lite * On CUDA also fuse MoE down * (up * unary(gate)) in case the MUL_MAT_ID op for the down experts is the next op in the graph. * Command line option to enable fused MoE up*unary(gate) * Add fmoe option to llama-bench * Adding forgotten gelu, relu, silu on ARM --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'src/llama.cpp')
-rw-r--r--src/llama.cpp44
1 files changed, 20 insertions, 24 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 28e887ee..eed7aa61 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2516,6 +2516,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool mla_attn;
+ bool fused_moe_up_gate;
enum llama_pooling_type pooling_type;
@@ -8628,30 +8629,20 @@ llm_expert_gating_func_type gating_op,
}
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
- ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
- cb(up, "ffn_moe_up", il);
-
- ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
- cb(gate, "ffn_moe_gate", il);
-
- // This is equivalent to the commented out code below
- ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
-
- //switch (type_op) {
- // case LLM_FFN_SILU:
- // {
- // gate = ggml_silu(ctx, gate);
- // cb(gate, "ffn_moe_silu", il);
- // } break;
- // case LLM_FFN_GELU:
- // {
- // gate = ggml_gelu(ctx, gate);
- // cb(gate, "ffn_moe_gelu", il);
- // } break;
- // default:
- // GGML_ABORT("fatal error");
- //}
- //ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
+
+ ggml_tensor * par;
+ if (lctx.cparams.fused_moe_up_gate) {
+ par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+ } else {
+ ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+ cb(up, "ffn_moe_up", il);
+
+ ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+ cb(gate, "ffn_moe_gate", il);
+
+ // This is equivalent to the commented out code below
+ par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+ }
cb(par, "ffn_moe_gate_par", il);
ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
@@ -8907,6 +8898,7 @@ struct llm_build_context {
const bool flash_attn;
const bool mla_attn;
+ const bool fused_moe_up_gate;
const enum llama_pooling_type pooling_type;
const enum llama_rope_type rope_type;
@@ -8958,6 +8950,7 @@ struct llm_build_context {
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
mla_attn (cparams.mla_attn),
+ fused_moe_up_gate(cparams.fused_moe_up_gate),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
cb (cb),
@@ -17605,6 +17598,7 @@ struct llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.mla_attn =*/ false,
+ /*.fused_moe_up_gate =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -17804,6 +17798,7 @@ struct llama_context * llama_new_context_with_model(
cparams.offload_kqv = params.offload_kqv;
cparams.flash_attn = params.flash_attn;
cparams.mla_attn = params.mla_attn;
+ cparams.fused_moe_up_gate= params.fused_moe_up_gate;
cparams.pooling_type = params.pooling_type;
cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -17871,6 +17866,7 @@ struct llama_context * llama_new_context_with_model(
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn);
+ LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate);
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);