summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
authorLaura <Tijntje_7@msn.com>2024-01-11 19:02:48 +0100
committerGitHub <noreply@github.com>2024-01-11 20:02:48 +0200
commit4330bd83feb39683de4bd7a34cfcf672ff8ac3e4 (patch)
tree22c6f63555d442b93cd84fde09868ba47c91cdaf /examples/server/server.cpp
parent27379455c38cb13f24de92dbd6fcdd04eeb1b9d9 (diff)
server : implement credentialed CORS (#4514)
* Implement credentialed CORS according to MDN * Fix syntax error * Move validate_api_key up so it is defined before its first usage
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp26
1 files changed, 20 insertions, 6 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 345004fa..031824e1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2822,9 +2822,15 @@ int main(int argc, char **argv)
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
- svr.set_default_headers({{"Server", "llama.cpp"},
- {"Access-Control-Allow-Origin", "*"},
- {"Access-Control-Allow-Headers", "content-type"}});
+ svr.set_default_headers({{"Server", "llama.cpp"}});
+
+ // CORS preflight
+ svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ res.set_header("Access-Control-Allow-Credentials", "true");
+ res.set_header("Access-Control-Allow-Methods", "POST");
+ res.set_header("Access-Control-Allow-Headers", "*");
+ });
svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
server_state current_state = state.load();
@@ -2987,9 +2993,9 @@ int main(int argc, char **argv)
return false;
});
- svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res)
+ svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res)
{
- res.set_header("Access-Control-Allow-Origin", "*");
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
{ "user_name", llama.name_user.c_str() },
{ "assistant_name", llama.name_assistant.c_str() }
@@ -2999,6 +3005,7 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@@ -3066,8 +3073,9 @@ int main(int argc, char **argv)
}
});
- svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
+ svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0);
json models = {
@@ -3085,9 +3093,11 @@ int main(int argc, char **argv)
res.set_content(models.dump(), "application/json; charset=utf-8");
});
+
// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@@ -3161,6 +3171,7 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!validate_api_key(req, res)) {
return;
}
@@ -3233,6 +3244,7 @@ int main(int argc, char **argv)
svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::vector<llama_token> tokens;
if (body.count("content") != 0)
@@ -3245,6 +3257,7 @@ int main(int argc, char **argv)
svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
std::string content;
if (body.count("tokens") != 0)
@@ -3259,6 +3272,7 @@ int main(int argc, char **argv)
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
{
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
json prompt;
if (body.count("content") != 0)