Diffstat (limited to 'src')
-rw-r--r-- | src/llama.cpp | 44
1 files changed, 20 insertions, 24 deletions
diff --git a/src/llama.cpp b/src/llama.cpp
index 28e887ee..eed7aa61 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2516,6 +2516,7 @@ struct llama_cparams {
     bool offload_kqv;
     bool flash_attn;
     bool mla_attn;
+    bool fused_moe_up_gate;
 
     enum llama_pooling_type pooling_type;
@@ -8628,30 +8629,20 @@ llm_expert_gating_func_type gating_op,
     }
 
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
-
-    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
-
-    // This is equivalent to the commented out code below
-    ggml_tensor * par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
-
-    //switch (type_op) {
-    //    case LLM_FFN_SILU:
-    //        {
-    //            gate = ggml_silu(ctx, gate);
-    //            cb(gate, "ffn_moe_silu", il);
-    //        } break;
-    //    case LLM_FFN_GELU:
-    //        {
-    //            gate = ggml_gelu(ctx, gate);
-    //            cb(gate, "ffn_moe_gelu", il);
-    //        } break;
-    //    default:
-    //        GGML_ABORT("fatal error");
-    //}
-    //ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
+
+    ggml_tensor * par;
+    if (lctx.cparams.fused_moe_up_gate) {
+        par = ggml_moe_up_gate(ctx, up_exps, gate_exps, cur, selected_experts, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+    } else {
+        ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(up, "ffn_moe_up", il);
+
+        ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(gate, "ffn_moe_gate", il);
+
+        // This is equivalent to the commented out code below
+        par = ggml_fused_mul_unary(ctx, gate, up, type_op == LLM_FFN_SILU ? GGML_UNARY_OP_SILU : GGML_UNARY_OP_GELU);
+    }
     cb(par, "ffn_moe_gate_par", il);
 
     ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
@@ -8907,6 +8898,7 @@ struct llm_build_context {
     const bool flash_attn;
     const bool mla_attn;
+    const bool fused_moe_up_gate;
 
     const enum llama_pooling_type pooling_type;
     const enum llama_rope_type rope_type;
@@ -8958,6 +8950,7 @@ struct llm_build_context {
        n_ctx_orig       (cparams.n_ctx_orig_yarn),
        flash_attn       (cparams.flash_attn),
        mla_attn         (cparams.mla_attn),
+       fused_moe_up_gate(cparams.fused_moe_up_gate),
        pooling_type     (cparams.pooling_type),
        rope_type        (hparams.rope_type),
        cb               (cb),
@@ -17605,6 +17598,7 @@ struct llama_context_params llama_context_default_params() {
        /*.offload_kqv            =*/ true,
        /*.flash_attn             =*/ false,
        /*.mla_attn               =*/ false,
+       /*.fused_moe_up_gate      =*/ false,
        /*.abort_callback         =*/ nullptr,
        /*.abort_callback_data    =*/ nullptr,
    };
@@ -17804,6 +17798,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.offload_kqv      = params.offload_kqv;
     cparams.flash_attn       = params.flash_attn;
     cparams.mla_attn         = params.mla_attn;
+    cparams.fused_moe_up_gate= params.fused_moe_up_gate;
     cparams.pooling_type     = params.pooling_type;
 
     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
@@ -17871,6 +17866,7 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: n_ubatch   = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: flash_attn = %d\n",   __func__, cparams.flash_attn);
     LLAMA_LOG_INFO("%s: mla_attn   = %d\n",   __func__, cparams.mla_attn);
+    LLAMA_LOG_INFO("%s: fused_moe  = %d\n",   __func__, cparams.fused_moe_up_gate);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",   __func__, cparams.rope_freq_scale);
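
Usage note: the new flag is plumbed through llama_context_params, so callers can opt into the fused MoE up/gate path when creating a context. Below is a minimal sketch, not part of this commit, assuming the standard llama.cpp model-loading API; the model path and n_ctx value are placeholders.

// Minimal usage sketch (hedged, not from this commit): opting into the fused
// MoE up/gate path via the public context parameters.
#include "llama.h"

int main() {
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("moe-model.gguf", mparams); // placeholder path
    if (!model) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx             = 4096;   // placeholder
    cparams.fused_moe_up_gate = true;   // route MoE FFNs through ggml_moe_up_gate

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) { llama_free_model(model); return 1; }

    // ... evaluate prompts as usual ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}

With the flag left at its default (false), graph construction falls back to the two llm_build_lora_mm_id calls plus ggml_fused_mul_unary, matching the previous behavior.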