summary | refs | log | tree | commit | diff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r-- common/common.cpp | 55
-rw-r--r-- common/common.h | 2
2 files changed, 57 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 8e6d74d0..4e823c52 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -690,6 +690,47 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
std::istreambuf_iterator<char>(),
std::back_inserter(sparams.grammar)
);
+ } else if (arg == "--override-kv") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ char * sep = strchr(argv[i], '=');
+ if (sep == nullptr || sep - argv[i] >= 128) {
+ fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ struct llama_model_kv_override kvo;
+ std::strncpy(kvo.key, argv[i], sep - argv[i]);
+ kvo.key[sep - argv[i]] = 0;
+ sep++;
+ if (strncmp(sep, "int:", 4) == 0) {
+ sep += 4;
+ kvo.tag = LLAMA_KV_OVERRIDE_INT;
+ kvo.int_value = std::atol(sep);
+ } else if (strncmp(sep, "float:", 6) == 0) {
+ sep += 6;
+ kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+ kvo.float_value = std::atof(sep);
+ } else if (strncmp(sep, "bool:", 5) == 0) {
+ sep += 5;
+ kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+ if (std::strcmp(sep, "true") == 0) {
+ kvo.bool_value = true;
+ } else if (std::strcmp(sep, "false") == 0) {
+ kvo.bool_value = false;
+ } else {
+ fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ } else {
+ fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ params.kv_overrides.push_back(kvo);
#ifndef LOG_DISABLE_LOGS
// Parse args for logging parameters
} else if ( log_param_single_parse( argv[i] ) ) {
@@ -733,6 +774,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
}
}
+ if (!params.kv_overrides.empty()) {
+ params.kv_overrides.emplace_back(llama_model_kv_override());
+ params.kv_overrides.back().key[0] = 0;
+ }
+
return true;
}
@@ -864,6 +910,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
printf(" -ld LOGDIR, --logdir LOGDIR\n");
printf(" path under which to save YAML logs (no logging if unset)\n");
+ printf(" --override-kv KEY=TYPE:VALUE\n");
+ printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
+ printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf("\n");
#ifndef LOG_DISABLE_LOGS
log_print_usage();
@@ -956,6 +1005,12 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
+ if (params.kv_overrides.empty()) {
+ mparams.kv_overrides = NULL;
+ } else {
+ GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+ mparams.kv_overrides = params.kv_overrides.data();
+ }
return mparams;
}
diff --git a/common/common.h b/common/common.h
index 534f7b13..02467938 100644
--- a/common/common.h
+++ b/common/common.h
@@ -86,6 +86,8 @@ struct gpt_params {
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::string logdir = ""; // directory in which to save YAML log files
+ std::vector<llama_model_kv_override> kv_overrides;
+
// TODO: avoid tuple, use struct
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter