summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorKerfuffle <44031344+KerfuffleV2@users.noreply.github.com>2023-08-17 07:29:44 -0600
committerGitHub <noreply@github.com>2023-08-17 07:29:44 -0600
commit8dae7ce68437faf1fa96ec0e7687b8700956ef20 (patch)
tree01a7bbd89d5b930a4334185ed9e01cf13708aea6
parenta73ccf1aa34de49f61bfeb7f8a679c3bfdb3abe3 (diff)
Add --cfg-negative-prompt-file option for examples (#2591)
Add --cfg-negative-prompt-file option for examples
-rw-r--r--examples/common.cpp19
1 files changed, 18 insertions, 1 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 9f8aab9a..bd39d922 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.cfg_negative_prompt = argv[i];
+ } else if (arg == "--cfg-negative-prompt-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.cfg_negative_prompt));
+ if (params.cfg_negative_prompt.back() == '\n') {
+ params.cfg_negative_prompt.pop_back();
+ }
} else if (arg == "--cfg-scale") {
if (++i >= argc) {
invalid_param = true;
@@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
- fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
+ fprintf(stdout, " --cfg-negative-prompt PROMPT\n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
+ fprintf(stdout, " --cfg-negative-prompt-file FNAME\n");
+ fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);