Diffstat (limited to 'common/common.cpp')
-rw-r--r--   common/common.cpp   35
1 files changed, 35 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 5c9070da..e62944b9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
     }
     return true;
 }
+template<class T1, class T2>
+std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
+    std::vector<std::pair<T1,T2>> values;
+    std::istringstream str_stream(str);
+    std::string token;
+    T1 first_value;
+    int i = 0;
+    while (std::getline(str_stream, token, delim)) {
+        std::istringstream token_stream(token);
+        if (i%2 == 0) {
+            token_stream >> first_value;
+        } else {
+            T2 value;
+            token_stream >> value;
+            values.emplace_back(first_value, value);
+        }
+        i++;
+    }
+    return values;
+}
 }
 
 #define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
@@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.fused_moe_up_gate = true;
         return true;
     }
+    if (arg == "-ser" || arg == "--smart-expert-reduction") {
+        CHECK_ARG
+        auto values = string_split_pairs<int,float>(argv[i], ',');
+        if (values.size() == 1) {
+            params.min_experts = values.front().first;
+            params.thresh_experts = values.front().second;
+        } else {
+            invalid_param = true;
+        }
+        return true;
+    }
     if (arg == "-co" || arg == "--color") {
         params.use_color = true;
         return true;
@@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
     options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
     options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+    options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
     options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
                              "in conversation mode, this will be used as system prompt\n"
                              "(default: '%s')", params.prompt.c_str() });
@@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.mla_attn          = params.mla_attn;
     cparams.attn_max_batch    = params.attn_max_batch;
     cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+    cparams.min_experts       = params.min_experts;
+    cparams.thresh_experts    = params.thresh_experts;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
     fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
     fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+    fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
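
Note (not part of the commit): the parsing path added here can be exercised in isolation. The sketch below copies the string_split_pairs helper from the first hunk and uses local stand-ins for the params.min_experts / params.thresh_experts fields to show how an argument such as "-ser 7,0.5" would be consumed; the value 7,0.5 and the stand-in variable names are illustrative assumptions, not taken from the diff.

// Standalone sketch of the -ser / --smart-expert-reduction value parsing.
// string_split_pairs is copied verbatim from the hunk above; min_experts and
// thresh_experts are local stand-ins for the gpt_params fields touched by the diff.
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

template <class T1, class T2>
std::vector<std::pair<T1, T2>> string_split_pairs(const std::string & str, char delim) {
    std::vector<std::pair<T1, T2>> values;
    std::istringstream str_stream(str);
    std::string token;
    T1 first_value;
    int i = 0;
    while (std::getline(str_stream, token, delim)) {
        std::istringstream token_stream(token);
        if (i % 2 == 0) {
            token_stream >> first_value;   // even-indexed tokens become the pair's first member
        } else {
            T2 value;
            token_stream >> value;         // odd-indexed tokens complete the pair
            values.emplace_back(first_value, value);
        }
        i++;
    }
    return values;
}

int main() {
    int   min_experts    = -1;     // stand-in for params.min_experts (default shown in the YAML dump)
    float thresh_experts = 0.0f;   // stand-in for params.thresh_experts

    const std::string arg_value = "7,0.5";   // hypothetical value passed to -ser
    auto values = string_split_pairs<int, float>(arg_value, ',');
    if (values.size() == 1) {
        min_experts    = values.front().first;
        thresh_experts = values.front().second;
        std::cout << "min_experts = " << min_experts
                  << ", thresh_experts = " << thresh_experts << "\n";
    } else {
        // mirrors the invalid_param = true branch in gpt_params_find_arg
        std::cout << "invalid -ser value: " << arg_value << "\n";
    }
    return 0;
}

As in the commit, exactly one int,float pair is accepted: anything other than a single pair (for example "-ser 7" or "-ser 7,0.5,3,0.2") leaves values.size() != 1 and the argument is rejected.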