summaryrefslogtreecommitdiff
path: root/common
diff options
context:
space:
mode:
Diffstat (limited to 'common')
-rw-r--r--common/common.cpp88
-rw-r--r--common/common.h2
2 files changed, 51 insertions, 39 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 9ab8aef7..d42fa131 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -234,8 +234,54 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
return result;
}
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides) {
+ const char * sep = strchr(data, '=');
+ if (sep == nullptr || sep - data >= 128) {
+ fprintf(stderr, "%s: malformed KV override '%s'\n", __func__, data);
+ return false;
+ }
+ llama_model_kv_override kvo;
+ std::strncpy(kvo.key, data, sep - data);
+ kvo.key[sep - data] = 0;
+ sep++;
+ if (strncmp(sep, "int:", 4) == 0) {
+ sep += 4;
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
+ kvo.val_i64 = std::atol(sep);
+ } else if (strncmp(sep, "float:", 6) == 0) {
+ sep += 6;
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+ kvo.val_f64 = std::atof(sep);
+ } else if (strncmp(sep, "bool:", 5) == 0) {
+ sep += 5;
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
+ if (std::strcmp(sep, "true") == 0) {
+ kvo.val_bool = true;
+ } else if (std::strcmp(sep, "false") == 0) {
+ kvo.val_bool = false;
+ } else {
+ fprintf(stderr, "%s: invalid boolean value for KV override '%s'\n", __func__, data);
+ return false;
+ }
+ } else if (strncmp(sep, "str:", 4) == 0) {
+ sep += 4;
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_STR;
+ if (strlen(sep) > 127) {
+ fprintf(stderr, "%s: malformed KV override '%s', value cannot exceed 127 chars\n", __func__, data);
+ return false;
+ }
+ strncpy(kvo.val_str, sep, 127);
+ kvo.val_str[127] = '\0';
+ } else {
+ fprintf(stderr, "%s: invalid type for KV override '%s'\n", __func__, data);
+ return false;
+ }
+ overrides.emplace_back(std::move(kvo));
+ return true;
+}
+
bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) {
- llama_sampling_params& sparams = params.sparams;
+ llama_sampling_params & sparams = params.sparams;
if (arg == "-s" || arg == "--seed") {
if (++i >= argc) {
@@ -1244,47 +1290,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
invalid_param = true;
return true;
}
- char* sep = strchr(argv[i], '=');
- if (sep == nullptr || sep - argv[i] >= 128) {
- fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
- invalid_param = true;
- return true;
- }
- struct llama_model_kv_override kvo;
- std::strncpy(kvo.key, argv[i], sep - argv[i]);
- kvo.key[sep - argv[i]] = 0;
- sep++;
- if (strncmp(sep, "int:", 4) == 0) {
- sep += 4;
- kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
- kvo.int_value = std::atol(sep);
- }
- else if (strncmp(sep, "float:", 6) == 0) {
- sep += 6;
- kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
- kvo.float_value = std::atof(sep);
- }
- else if (strncmp(sep, "bool:", 5) == 0) {
- sep += 5;
- kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
- if (std::strcmp(sep, "true") == 0) {
- kvo.bool_value = true;
- }
- else if (std::strcmp(sep, "false") == 0) {
- kvo.bool_value = false;
- }
- else {
- fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
- invalid_param = true;
- return true;
- }
- }
- else {
+ if (!parse_kv_override(argv[i], params.kv_overrides)) {
fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
invalid_param = true;
return true;
}
- params.kv_overrides.push_back(kvo);
return true;
}
#ifndef LOG_DISABLE_LOGS
@@ -1555,7 +1565,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" path to dynamic lookup cache to use for lookup decoding (updated by generation)\n");
printf(" --override-kv KEY=TYPE:VALUE\n");
printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
- printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+ printf(" types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -ptc N, --print-token-count N\n");
printf(" print token count every N tokens (default: %d)\n", params.n_print);
printf(" --check-tensors check model tensor data for invalid values\n");
diff --git a/common/common.h b/common/common.h
index 0005f143..96a28a6c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -171,6 +171,8 @@ struct gpt_params {
std::string image = ""; // path to an image file
};
+bool parse_kv_override(const char * data, std::vector<llama_model_kv_override> & overrides);
+
bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);