diff options
Diffstat (limited to 'common')
-rw-r--r-- | common/common.cpp | 7 | ||||
-rw-r--r-- | common/common.h | 2 |
2 files changed, 5 insertions, 4 deletions
diff --git a/common/common.cpp b/common/common.cpp index 464b4710..6359426f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -851,7 +851,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "-mla" || arg == "--mla-use") { - params.mla_attn = true; + CHECK_ARG + params.mla_attn = std::stoi(argv[i]); return true; } if (arg == "-fmoe" || arg == "--fused-moe") { @@ -1514,7 +1515,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep }); options.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks }); options.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" }); - options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %s)", params.mla_attn ? "enabled" : "disabled" }); + options.push_back({ "*", "-mla, --mla-use", "enable MLA (default: %d)", params.mla_attn }); options.push_back({ "*", "-fmoe, --fused-moe", "enable fused MoE (default: %s)", params.fused_moe_up_gate ? "enabled" : "disabled" }); options.push_back({ "*", "-p, --prompt PROMPT", "prompt to start generation with\n" "in conversation mode, this will be used as system prompt\n" @@ -3357,7 +3358,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "mla_attn: %s # default: false\n", params.mla_attn ? "true" : "false"); + fprintf(stream, "mla_attn: %d # default: 0\n", params.mla_attn); fprintf(stream, "fused_moe: %s # default: false\n", params.fused_moe_up_gate ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); diff --git a/common/common.h b/common/common.h index 152fd1cf..ef5175f3 100644 --- a/common/common.h +++ b/common/common.h @@ -175,7 +175,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool flash_attn = false; // flash attention - bool mla_attn = false; // MLA + int mla_attn = false; // MLA 0: standard attention, 1: MLA with K and transposed V cache, 2: MLA with just K cache bool fused_moe_up_gate = false; // fused up*unary(gate) op for MoE models bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix |