diff options
Diffstat (limited to 'src/llama.cpp')
-rw-r--r-- | src/llama.cpp | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/src/llama.cpp b/src/llama.cpp index 0dcc78dc..3a8b54ca 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2513,6 +2513,8 @@ struct llama_cparams { int mla_attn; int attn_max_batch; bool fused_moe_up_gate; + int min_experts; + float thresh_experts; enum llama_pooling_type pooling_type; @@ -8631,7 +8633,8 @@ llm_expert_gating_func_type gating_op, } // select experts - ggml_tensor * selected_experts = ggml_top_k(ctx, selection_probs, n_expert_used); // [n_expert_used, n_tokens] + ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, + lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] cb(selected_experts->src[0], "ffn_moe_argsort", il); cb(selected_experts, "ffn_moe_topk", il); @@ -8974,6 +8977,8 @@ struct llm_build_context { const int mla_attn; const int attn_max_batch; const bool fused_moe_up_gate; + const int min_experts; + const float thresh_experts; const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; @@ -9027,6 +9032,8 @@ struct llm_build_context { mla_attn (cparams.mla_attn), attn_max_batch (cparams.attn_max_batch), fused_moe_up_gate(cparams.fused_moe_up_gate), + min_experts (cparams.min_experts), + thresh_experts (cparams.thresh_experts), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), cb (cb), @@ -17725,6 +17732,8 @@ struct llama_context_params llama_context_default_params() { /*.mla_attn =*/ 0, /*.attn_max_batch =*/ 0, /*.fused_moe_up_gate =*/ false, + /*.min_experts =*/ -1, + /*.thtesh_experts =*/ 0.0f, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, }; @@ -17926,6 +17935,9 @@ struct llama_context * llama_new_context_with_model( cparams.mla_attn = params.mla_attn; cparams.attn_max_batch = params.attn_max_batch; cparams.fused_moe_up_gate= params.fused_moe_up_gate; + cparams.min_experts = params.min_experts; + cparams.thresh_experts = params.thresh_experts; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; @@ -17995,6 +18007,7 @@ struct llama_context * llama_new_context_with_model( LLAMA_LOG_INFO("%s: mla_attn = %d\n", __func__, cparams.mla_attn); LLAMA_LOG_INFO("%s: attn_max_b = %d\n", __func__, cparams.attn_max_batch); LLAMA_LOG_INFO("%s: fused_moe = %d\n", __func__, cparams.fused_moe_up_gate); + LLAMA_LOG_INFO("%s: ser = %d, %g\n", __func__, cparams.min_experts, cparams.thresh_experts); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); |