summary refs log tree commit diff
path: root/common/json-schema-to-grammar.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'common/json-schema-to-grammar.cpp')
-rw-r--r-- common/json-schema-to-grammar.cpp 351
1 files changed, 335 insertions, 16 deletions
diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp
index 10b9b3d1..881eb49e 100644
--- a/common/json-schema-to-grammar.cpp
+++ b/common/json-schema-to-grammar.cpp
@@ -40,6 +40,233 @@ static std::string build_repetition(const std::string & item_rule, int min_items
return result;
}
+/* Minimalistic replacement for std::string_view, which is only available from C++17 onwards.
+ *
+ * Describes the half-open character range [_start, _end) of a std::string that is held
+ * by reference. The characters are never copied into the view, so the referenced
+ * std::string has to outlive every view (and sub-view) created from it.
+ */
+class string_view {
+ const std::string & _str; // backing string, held by reference (not owned)
+ const size_t _start; // index in _str of the first character of the view
+ const size_t _end; // index in _str one past the last character of the view
+public:
+ string_view(const std::string & str, size_t start = 0, size_t end = std::string::npos) : _str(str), _start(start), _end(end == std::string::npos ? str.length() : end) {} // end == npos means "up to the end of str"; start/end are not validated
+
+ size_t size() const { // number of characters covered by the view
+ return _end - _start;
+ }
+
+ size_t length() const { // alias of size()
+ return size();
+ }
+
+ operator std::string() const { // implicit conversion; same result as str()
+ return str();
+ }
+
+ std::string str() const { // returns a copy of the viewed characters
+ return _str.substr(_start, _end - _start);
+ }
+
+ string_view substr(size_t pos, size_t len = std::string::npos) const { // pos/len are relative to this view; the result refers to the same backing string
+ return string_view(_str, _start + pos, len == std::string::npos ? _end : _start + pos + len); // no clamping: pos and len are not checked against the end of this view
+ }
+
+ char operator[](size_t pos) const { // bounds-checked access; pos is relative to the start of the view
+ auto index = _start + pos;
+ if (index >= _end) {
+ throw std::out_of_range("string_view index out of range");
+ }
+ return _str[_start + pos];
+ }
+
+ bool operator==(const string_view & other) const { // content comparison, done on temporary std::string copies of both views
+ std::string this_str = *this;
+ std::string other_str = other;
+ return this_str == other_str;
+ }
+};
+
+static void _build_min_max_int(int min_value, int max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { /*
+ * Writes to `out` a grammar expression that matches the decimal integers from min_value
+ * to max_value (both ends included).
+ *
+ * min_value   lower bound; std::numeric_limits<int>::min() means "no lower bound".
+ * max_value   upper bound; std::numeric_limits<int>::max() means "no upper bound".
+ * out         stream the expression text is appended to.
+ * decimals_left  used as the upper repetition count of the [0-9] runs emitted for
+ *             open-ended ranges; it is decremented (never below 1) on the way down.
+ * top_level   selects how a number may start: for min_value == 0 the top-level form is
+ *             `[0] | [1-9] [0-9]{0,n}` while the nested form is `[0-9]{1,n}` (presumably
+ *             because nested calls describe the tail of a longer number - confirm).
+ *
+ * Throws std::runtime_error when neither bound is set.
+ * Uses the `repeat` helper, which is defined outside the lines shown here.
+ */
+ auto has_min = min_value != std::numeric_limits<int>::min();
+ auto has_max = max_value != std::numeric_limits<int>::max();
+
+ auto digit_range = [&](char from, char to) { // emits "[c]" when from == to, otherwise "[from-to]"
+ out << "[";
+ if (from == to) {
+ out << from;
+ } else {
+ out << from << "-" << to;
+ }
+ out << "]";
+ };
+ auto more_digits = [&](int min_digits, int max_digits) { // emits "[0-9]" followed by a repetition suffix
+ out << "[0-9]";
+ if (min_digits == max_digits && min_digits == 1) { // exactly one digit: no suffix needed
+ return;
+ }
+ out << "{";
+ out << min_digits;
+ if (max_digits != min_digits) {
+ out << ",";
+ if (max_digits != std::numeric_limits<int>::max()) { // INT_MAX leaves the upper count open: "{min,}"
+ out << max_digits;
+ }
+ }
+ out << "}";
+ };
+ std::function<void(const string_view &, const string_view &)> uniform_range = // std::function so the lambda can call itself; handles two digit strings, NOTE(review): sub_len is taken from `from` only, so equal lengths appear to be assumed (the calls below pass equal-length strings) - confirm
+ [&](const string_view & from, const string_view & to) {
+ size_t i = 0;
+ while (i < from.length() && i < to.length() && from[i] == to[i]) { // length of the common prefix
+ i++;
+ }
+ if (i > 0) {
+ out << "\"" << from.substr(0, i).str() << "\""; // common prefix is emitted as a quoted literal
+ }
+ if (i < from.length() && i < to.length()) { // something differs after the prefix
+ if (i > 0) {
+ out << " ";
+ }
+ auto sub_len = from.length() - i - 1; // number of digits after the first differing digit
+ if (sub_len > 0) {
+ auto from_sub = from.substr(i + 1);
+ auto to_sub = to.substr(i + 1);
+ auto sub_zeros = repeat("0", sub_len);
+ auto sub_nines = repeat("9", sub_len);
+
+ auto to_reached = false; // set once the alternative covering to[i] has been emitted
+ out << "(";
+ if (from_sub == sub_zeros) { // tail of `from` is all zeros: every tail is allowed below to[i]
+ digit_range(from[i], to[i] - 1);
+ out << " ";
+ more_digits(sub_len, sub_len);
+ } else {
+ out << "[" << from[i] << "] "; // leading digit from[i] with tails from from_sub up to 9...9
+ out << "(";
+ uniform_range(from_sub, sub_nines);
+ out << ")";
+ if (from[i] < to[i] - 1) { // leading digits strictly between from[i] and to[i] take any tail
+ out << " | ";
+ if (to_sub == sub_nines) { // tail of `to` is all nines: to[i] can be folded into the same range
+ digit_range(from[i] + 1, to[i]);
+ to_reached = true;
+ } else {
+ digit_range(from[i] + 1, to[i] - 1);
+ }
+ out << " ";
+ more_digits(sub_len, sub_len);
+ }
+ }
+ if (!to_reached) { // leading digit to[i] with tails from 0...0 up to to_sub
+ out << " | ";
+ digit_range(to[i], to[i]);
+ out << " ";
+ uniform_range(sub_zeros, to_sub);
+ }
+ out << ")";
+ } else { // only one digit differs: a plain character range is enough
+ out << "[" << from[i] << "-" << to[i] << "]";
+ }
+ }
+ };
+
+ if (has_min && has_max) { // closed range
+ if (min_value < 0 && max_value < 0) { // entirely negative: "-" followed by the mirrored positive range
+ out << "\"-\" (";
+ _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true);
+ out << ")";
+ return;
+ }
+
+ if (min_value < 0) { // range spans zero: emit the negative part, then continue from 0
+ out << "\"-\" (";
+ _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true);
+ out << ") | ";
+ min_value = 0;
+ }
+
+ auto min_s = std::to_string(min_value);
+ auto max_s = std::to_string(max_value);
+ auto min_digits = min_s.length();
+ auto max_digits = max_s.length();
+
+ for (auto digits = min_digits; digits < max_digits; digits++) { // one alternative per digit count shorter than max_s
+ uniform_range(min_s, repeat("9", digits));
+ min_s = "1" + repeat("0", digits); // smallest number with one more digit
+ out << " | ";
+ }
+ uniform_range(min_s, max_s); // here min_s and max_s have the same number of digits
+ return;
+ }
+
+ auto less_decimals = std::max(decimals_left - 1, 1); // digit budget passed down, never below 1
+
+ if (has_min) { // only a lower bound
+ if (min_value < 0) { // negatives down to min_value, or any non-negative number
+ out << "\"-\" (";
+ _build_min_max_int(std::numeric_limits<int>::min(), -min_value, out, decimals_left, /* top_level= */ false);
+ out << ") | [0] | [1-9] ";
+ more_digits(0, decimals_left - 1);
+ } else if (min_value == 0) {
+ if (top_level) {
+ out << "[0] | [1-9] ";
+ more_digits(0, less_decimals);
+ } else {
+ more_digits(1, decimals_left);
+ }
+ } else if (min_value <= 9) { // single-digit lower bound
+ char c = '0' + min_value;
+ auto range_start = top_level ? '1' : '0';
+ if (c > range_start) { // a smaller first digit is only allowed when more digits follow
+ digit_range(range_start, c - 1);
+ out << " ";
+ more_digits(1, less_decimals);
+ out << " | ";
+ }
+ digit_range(c, '9');
+ out << " ";
+ more_digits(0, less_decimals);
+ } else { // lower bound with two or more digits
+ auto min_s = std::to_string(min_value);
+ auto len = min_s.length();
+ auto c = min_s[0]; // leading digit of the bound
+
+ if (c > '1') { // smaller leading digit, followed by at least len more digits
+ digit_range(top_level ? '1' : '0', c - 1);
+ out << " ";
+ more_digits(len, less_decimals);
+ out << " | ";
+ }
+ digit_range(c, c); // same leading digit: the remaining digits are handled recursively
+ out << " (";
+ _build_min_max_int(std::stoi(min_s.substr(1)), std::numeric_limits<int>::max(), out, less_decimals, /* top_level= */ false);
+ out << ")";
+ if (c < '9') { // larger leading digit, followed by at least len - 1 more digits
+ out << " | ";
+ digit_range(c + 1, '9');
+ out << " ";
+ more_digits(len - 1, less_decimals);
+ }
+ }
+ return;
+ }
+
+ if (has_max) { // only an upper bound
+ if (max_value >= 0) {
+ if (top_level) { // any negative number, then the closed range 0..max_value
+ out << "\"-\" [1-9] ";
+ more_digits(0, less_decimals);
+ out << " | ";
+ }
+ _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true);
+ } else { // negative upper bound: "-" followed by magnitudes of at least -max_value
+ out << "\"-\" (";
+ _build_min_max_int(-max_value, std::numeric_limits<int>::max(), out, decimals_left, /* top_level= */ false);
+ out << ")";
+ }
+ return;
+ }
+
+ throw std::runtime_error("At least one of min_value or max_value must be set");
+}
+
const std::string SPACE_RULE = "| \" \" | \"\\n\" [ \\t]{0,20}";
struct BuiltinRule {
@@ -89,7 +316,7 @@ std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
-std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
+std::unordered_set<char> ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'};
template <typename Iterator>
std::string join(Iterator begin, Iterator end, const std::string & separator) {
@@ -160,7 +387,6 @@ static std::string format_literal(const std::string & literal) {
return "\"" + escaped + "\"";
}
-
class SchemaConverter {
private:
std::function<json(const std::string &)> _fetch_json;
@@ -388,6 +614,75 @@ private:
return _add_rule(name, "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space");
}
+ /*
+ Returns a rule that matches a JSON string that is none of the provided strings
+
+ not_strings({"a"})
+ -> ["] ( [a] char+ | [^"a] char* )? ["] space
+ not_strings({"and", "also"})
+ -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space
+
+ The provided strings are inserted into a character trie, which is then walked
+ depth-first to write the alternatives shown above. The return value is the text
+ of the rule body, not the name of a registered rule. As a side effect the "char"
+ primitive is registered through _add_primitive.
+
+ NOTE(review): characters of the provided strings are written into [...] classes
+ as-is; strings containing characters such as ']' , '\' , '^' or '"' would
+ presumably need escaping - confirm against the grammar syntax.
+ NOTE(review): a string that is a proper, non-empty prefix of a provided string
+ (e.g. "a" for {"and"}) does not appear to be matched by the generated rule,
+ because only the outermost group is made optional - confirm whether intended.
+ */
+ std::string _not_strings(const std::vector<std::string> & strings) {
+
+ struct TrieNode { // one node per character position of the provided strings
+ std::map<char, TrieNode> children; // std::map, so children are visited in ascending char order
+ bool is_end_of_string; // true when one of the provided strings ends at this node
+
+ TrieNode() : is_end_of_string(false) {}
+
+ void insert(const std::string & string) { // adds `string`, creating missing nodes on the way
+ auto node = this;
+ for (char c : string) {
+ node = &node->children[c]; // operator[] default-constructs the child when absent
+ }
+ node->is_end_of_string = true;
+ }
+ };
+
+ TrieNode trie;
+ for (const auto & s : strings) {
+ trie.insert(s);
+ }
+
+ std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char"));
+ std::ostringstream out;
+ out << "[\"] ( "; // opening quote of the JSON string, then the group of alternatives
+ std::function<void(const TrieNode &)> visit = [&](const TrieNode & node) { // std::function so the lambda can recurse
+ std::ostringstream rejects; // collects the child characters for the closing negated class
+ auto first = true;
+ for (const auto & kv : node.children) {
+ rejects << kv.first;
+ if (first) {
+ first = false;
+ } else {
+ out << " | ";
+ }
+ out << "[" << kv.first << "]";
+ if (!kv.second.children.empty()) { // longer provided strings continue below this character
+ out << " (";
+ visit(kv.second);
+ out << ")";
+ } else if (kv.second.is_end_of_string) { // leaf: a provided string ends here, so at least one more char is required
+ out << " " << char_rule << "+";
+ }
+ }
+ if (!node.children.empty()) { // any other (non-quote) character may be followed by arbitrary chars
+ if (!first) {
+ out << " | ";
+ }
+ out << "[^\"" << rejects.str() << "] " << char_rule << "*";
+ }
+ };
+ visit(trie);
+
+ out << " )";
+ if (!trie.is_end_of_string) { // the empty string is accepted unless it is one of the provided strings
+ out << "?";
+ }
+ out << " [\"] space"; // closing quote and trailing whitespace rule
+ return out.str();
+ }
+
std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
@@ -408,6 +703,7 @@ private:
std::vector<std::string> required_props;
std::vector<std::string> optional_props;
std::unordered_map<std::string, std::string> prop_kv_rule_names;
+ std::vector<std::string> prop_names;
for (const auto & kv : properties) {
const auto &prop_name = kv.first;
const auto &prop_schema = kv.second;
@@ -422,11 +718,18 @@ private:
} else {
optional_props.push_back(prop_name);
}
+ prop_names.push_back(prop_name);
}
- if (additional_properties.is_object() || (additional_properties.is_boolean() && additional_properties.get<bool>())) {
+ if ((additional_properties.is_boolean() && additional_properties.get<bool>()) || additional_properties.is_object()) {
std::string sub_name = name + (name.empty() ? "" : "-") + "additional";
- std::string value_rule = visit(additional_properties.is_object() ? additional_properties : json::object(), sub_name + "-value");
- std::string kv_rule = _add_rule(sub_name + "-kv", _add_primitive("string", PRIMITIVE_RULES.at("string")) + " \":\" space " + value_rule);
+ std::string value_rule =
+ additional_properties.is_object() ? visit(additional_properties, sub_name + "-value")
+ : _add_primitive("value", PRIMITIVE_RULES.at("value"));
+
+ auto key_rule =
+ prop_names.empty() ? _add_primitive("string", PRIMITIVE_RULES.at("string"))
+ : _add_rule(sub_name + "-k", _not_strings(prop_names));
+ std::string kv_rule = _add_rule(sub_name + "-kv", key_rule + " \":\" space " + value_rule);
prop_kv_rule_names["*"] = kv_rule;
optional_props.push_back("*");
}
@@ -452,15 +755,11 @@ private:
}
std::string k = ks[0];
std::string kv_rule_name = prop_kv_rule_names[k];
- if (k == "*") {
- res = _add_rule(
- name + (name.empty() ? "" : "-") + "additional-kvs",
- kv_rule_name + " ( \",\" space " + kv_rule_name + " )*"
- );
- } else if (first_is_optional) {
- res = "( \",\" space " + kv_rule_name + " )?";
+ std::string comma_ref = "( \",\" space " + kv_rule_name + " )";
+ if (first_is_optional) {
+ res = comma_ref + (k == "*" ? "*" : "?");
} else {
- res = kv_rule_name;
+ res = kv_rule_name + (k == "*" ? " " + comma_ref + "*" : "");
}
if (ks.size() > 1) {
res += " " + _add_rule(
@@ -594,17 +893,19 @@ public:
} else if (schema_type.is_array()) {
std::vector<json> schema_types;
for (const auto & t : schema_type) {
- schema_types.push_back({{"type", t}});
+ json schema_copy(schema);
+ schema_copy["type"] = t;
+ schema_types.push_back(schema_copy);
}
return _add_rule(rule_name, _generate_union_rule(name, schema_types));
} else if (schema.contains("const")) {
- return _add_rule(rule_name, _generate_constant_rule(schema["const"]));
+ return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space");
} else if (schema.contains("enum")) {
std::vector<std::string> enum_values;
for (const auto & v : schema["enum"]) {
enum_values.push_back(_generate_constant_rule(v));
}
- return _add_rule(rule_name, join(enum_values.begin(), enum_values.end(), " | "));
+ return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
} else if ((schema_type.is_null() || schema_type == "object")
&& (schema.contains("properties") ||
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
@@ -686,6 +987,24 @@ public:
int min_len = schema.contains("minLength") ? schema["minLength"].get<int>() : 0;
int max_len = schema.contains("maxLength") ? schema["maxLength"].get<int>() : std::numeric_limits<int>::max();
return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space");
+ } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) {
+ int min_value = std::numeric_limits<int>::min();
+ int max_value = std::numeric_limits<int>::max();
+ if (schema.contains("minimum")) {
+ min_value = schema["minimum"].get<int>();
+ } else if (schema.contains("exclusiveMinimum")) {
+ min_value = schema["exclusiveMinimum"].get<int>() + 1;
+ }
+ if (schema.contains("maximum")) {
+ max_value = schema["maximum"].get<int>();
+ } else if (schema.contains("exclusiveMaximum")) {
+ max_value = schema["exclusiveMaximum"].get<int>() - 1;
+ }
+ std::stringstream out;
+ out << "(";
+ _build_min_max_int(min_value, max_value, out);
+ out << ") space";
+ return _add_rule(rule_name, out.str());
} else if (schema.empty() || schema_type == "object") {
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
} else {