summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp18
-rw-r--r--common/common.h2
2 files changed, 20 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index eacaee18..6b4913a6 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
+ } else if (arg == "--grp-attn-n" || arg == "-gan") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+
+ params.grp_attn_n = std::stoi(argv[i]);
+ } else if (arg == "--grp-attn-w" || arg == "-gaw") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+
+ params.grp_attn_w = std::stoi(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
#endif
+ printf(" -gan N, --grp-attn-n N\n");
+ printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
+ printf(" -gat N, --grp-attn-w N\n");
+ printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
diff --git a/common/common.h b/common/common.h
index 9659aa04..e2bbfc25 100644
--- a/common/common.h
+++ b/common/common.h
@@ -62,6 +62,8 @@ struct gpt_params {
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.
+ int32_t grp_attn_n = 1; // group-attention factor
+ int32_t grp_attn_w = 512; // group-attention width
float rope_freq_base = 0.0f; // RoPE base frequency
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor