summaryrefslogtreecommitdiff
path: root/examples/common.cpp
diff options
context:
space:
mode:
authorEvan Jones <evan.q.jones@gmail.com>2023-07-23 23:58:10 -0400
committerGitHub <noreply@github.com>2023-07-23 23:58:10 -0400
commit84e09a7d8bc4ab6d658b5cd81295ac0add60be78 (patch)
tree934c5480d917325ac8baa29f4edfae99137b56bb /examples/common.cpp
parent2f9cf974a066ac0e03fbb235d834b01b0164d743 (diff)
llama : add grammar-based sampling (#1773)
* llama, main : constrain sampling to grammar * allow loading grammar from file * fix whitespace errors * handle & print parser errors * add comments to grammar syntax and allow newlines where unambiguous * add missing include * support alternates in root rule * fix bugs with empty token and EOS * adjust JSON grammar * remove swp file * rewrite ternary expressions Co-authored-by: Henri Vasserman <henv@hot.ee> * use struct for grammar elements and add Unicode support * add unicode escapes * add inverse char ranges * only sample full tokens (no peeking or truncation) * llama : minor style changes blindly applied in online editor - hopefully I didn't break something * update help text * add warning message if EOS is disabled --------- Co-authored-by: Henri Vasserman <henv@hot.ee> Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/common.cpp')
-rw-r--r--examples/common.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/examples/common.cpp b/examples/common.cpp
index 7a1928f2..779605f9 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -438,6 +438,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.input_suffix = argv[i];
+ } else if (arg == "--grammar") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.grammar = argv[i];
+ } else if (arg == "--grammar-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(
+ std::istreambuf_iterator<char>(file),
+ std::istreambuf_iterator<char>(),
+ std::back_inserter(params.grammar)
+ );
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params);
@@ -514,6 +536,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stdout, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+ fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
+ fprintf(stdout, " --grammar-file FNAME file to read grammar from\n");
fprintf(stdout, " --cfg-negative-prompt PROMPT \n");
fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n");
fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);