summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp14
1 files changed, 13 insertions, 1 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d2a8e541..cf075d6c 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,6 +1,7 @@
#include "utils.hpp"
#include "common.h"
+#include "json-schema-to-grammar.h"
#include "llama.h"
#include "grammar-parser.h"
@@ -178,6 +179,7 @@ struct server_slot {
llama_token sampled;
struct llama_sampling_params sparams;
llama_sampling_context * ctx_sampling = nullptr;
+ json json_schema;
int32_t ga_i = 0; // group-attention state
int32_t ga_n = 1; // group-attention factor
@@ -845,7 +847,17 @@ struct server_context {
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
slot.params.seed = json_value(data, "seed", default_params.seed);
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+ if (data.contains("json_schema") && !data.contains("grammar")) {
+ try {
+ auto schema = json_value(data, "json_schema", json::object());
+ slot.sparams.grammar = json_schema_to_grammar(schema);
+ } catch (const std::exception & e) {
+ send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ } else {
+ slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
+ }
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);