summaryrefslogtreecommitdiff
path: root/common/common.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/common.cpp')
-rw-r--r--common/common.cpp15
1 files changed, 15 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp
index dda51478..52576cba 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1,4 +1,6 @@
#include "common.h"
+#include "json.hpp"
+#include "json-schema-to-grammar.h"
#include "llama.h"
#include <algorithm>
@@ -68,6 +70,8 @@
#define LLAMA_CURL_MAX_HEADER_LENGTH 256
#endif // LLAMA_USE_CURL
+using json = nlohmann::ordered_json;
+
int32_t get_num_physical_cores() {
#ifdef __linux__
// enumerate the set of thread siblings, num entries is num cores
@@ -1148,6 +1152,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
);
return true;
}
+ if (arg == "-j" || arg == "--json-schema") {
+ if (++i >= argc) {
+ invalid_param = true;
+ return true;
+ }
+ sparams.grammar = json_schema_to_grammar(json::parse(argv[i]));
+ return true;
+ }
if (arg == "--override-kv") {
if (++i >= argc) {
invalid_param = true;
@@ -1353,6 +1365,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
printf(" --grammar-file FNAME file to read grammar from\n");
+ printf(" -j SCHEMA, --json-schema SCHEMA\n");
+ printf(" JSON schema to constrain generations (https://json-schema.org/), e.g. `{}` for any JSON object.\n");
+ printf(" For schemas w/ external $refs, use --grammar + example/json_schema_to_grammar.py instead\n");
printf(" --cfg-negative-prompt PROMPT\n");
printf(" negative prompt to use for guidance. (default: empty)\n");
printf(" --cfg-negative-prompt-file FNAME\n");