Diffstat (limited to 'common/common.cpp')
 common/common.cpp | 35 +++++++++++++++++++++++++++++++++++
 1 file changed, 35 insertions(+), 0 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index 5c9070da..e62944b9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -322,6 +322,26 @@ bool parse_buft_overrides(const std::string& value, std::vector<llama_model_tens
}
return true;
}
+template<class T1, class T2>
+std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
+ std::vector<std::pair<T1,T2>> values;
+ std::istringstream str_stream(str);
+ std::string token;
+ T1 first_value;
+ int i = 0;
+ while (std::getline(str_stream, token, delim)) {
+ std::istringstream token_stream(token);
+ if (i%2 == 0) {
+ token_stream >> first_value;
+ } else {
+ T2 value;
+ token_stream >> value;
+ values.emplace_back(first_value, value);
+ }
+ i++;
+ }
+ return values;
+}
}
#define CHECK_ARG if (++i >= argc) { invalid_param = true; return true; }
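
Not part of the patch: a minimal, standalone sketch of how the new string_split_pairs helper behaves, with the template copied from the hunk above. Tokens are split on the delimiter and consumed in alternating (T1, T2) order; a trailing unpaired token is silently dropped, and a token that fails to parse is not reported as an error.

    // build: g++ -std=c++11 split_pairs_demo.cpp   (file name is illustrative)
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    // copied verbatim from the hunk above
    template<class T1, class T2>
    std::vector<std::pair<T1,T2>> string_split_pairs(const std::string & str, char delim) {
        std::vector<std::pair<T1,T2>> values;
        std::istringstream str_stream(str);
        std::string token;
        T1 first_value;
        int i = 0;
        while (std::getline(str_stream, token, delim)) {
            std::istringstream token_stream(token);
            if (i%2 == 0) {
                token_stream >> first_value;   // even-indexed token: first half of the next pair
            } else {
                T2 value;
                token_stream >> value;         // odd-indexed token: second half, pair is emitted
                values.emplace_back(first_value, value);
            }
            i++;
        }
        return values;
    }

    int main() {
        // "7,0.25,4,0.5" -> {(7, 0.25f), (4, 0.5f)}
        for (const auto & p : string_split_pairs<int,float>("7,0.25,4,0.5", ',')) {
            std::cout << p.first << " -> " << p.second << "\n";
        }
        // "7" alone yields an empty vector: the lone token never gets a partner
        std::cout << string_split_pairs<int,float>("7", ',').size() << " pairs\n";
    }
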
@@ -864,6 +884,17 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
params.fused_moe_up_gate = true;
return true;
}
+ if (arg == "-ser" || arg == "--smart-expert-reduction") {
+ CHECK_ARG
+ auto values = string_split_pairs<int,float>(argv[i], ',');
+ if (values.size() == 1) {
+ params.min_experts = values.front().first;
+ params.thresh_experts = values.front().second;
+ } else {
+ invalid_param = true;
+ }
+ return true;
+ }
if (arg == "-co" || arg == "--color") {
params.use_color = true;
return true;
@@ -1523,6 +1554,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn });
options.push_back({ "*", "-amb, --attention-max-batch", "max batch size for attention computations (default: %d)", params.attn_max_batch});
options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" });
+ options.push_back({ "*", "-ser, --smart-expert-reduction,","experts reduction (default: %d,%g)", params.min_experts, params.thresh_experts});
options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n"
"in conversation mode, this will be used as system prompt\n"
"(default: '%s')", params.prompt.c_str() });
@@ -2368,6 +2400,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.mla_attn = params.mla_attn;
cparams.attn_max_batch = params.attn_max_batch;
cparams.fused_moe_up_gate = params.fused_moe_up_gate;
+ cparams.min_experts = params.min_experts;
+ cparams.thresh_experts = params.thresh_experts;
cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
@@ -3368,6 +3402,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn);
fprintf(stream, "attn_max_batch: %d # default: 0\n", params.attn_max_batch);
fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false");
+ fprintf(stream, "ser: %d,%g # defaulr: -1,0\n", params.min_experts, params.thresh_experts);
fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());