diff options
author | Michael Coppola <m18coppola@gmail.com> | 2024-01-11 12:51:17 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-11 19:51:17 +0200 |
commit | 27379455c38cb13f24de92dbd6fcdd04eeb1b9d9 (patch) | |
tree | 150a38b7780600f8c7494357c0b174df338ee71c /examples/server/server.cpp | |
parent | eab67950068e4b125007d027232c47d2a5831cd0 (diff) |
server : support for multiple api keys (#4864)
* server: added support for multiple api keys, added loading api keys from file
* minor: fix whitespace
* added file error handling to --api-key-file, changed code to better
reflect current style
* server: update README.md for --api-key-file
---------
Co-authored-by: Michael Coppola <info@michaeljcoppola.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 36 |
1 files changed, 30 insertions, 6 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 51a4b689..345004fa 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -39,7 +39,7 @@ using json = nlohmann::json; struct server_params { std::string hostname = "127.0.0.1"; - std::string api_key; + std::vector<std::string> api_keys; std::string public_path = "examples/server/public"; int32_t port = 8080; int32_t read_timeout = 600; @@ -2021,6 +2021,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" --port PORT port to listen (default (default: %d)\n", sparams.port); printf(" --path PUBLIC_PATH path from which to serve static files (default %s)\n", sparams.public_path.c_str()); printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n"); + printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n"); printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout); printf(" --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled"); printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel); @@ -2081,7 +2082,28 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } - sparams.api_key = argv[i]; + sparams.api_keys.push_back(argv[i]); + } + else if (arg == "--api-key-file") + { + if (++i >= argc) + { + invalid_param = true; + break; + } + std::ifstream key_file(argv[i]); + if (!key_file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::string key; + while (std::getline(key_file, key)) { + if (key.size() > 0) { + sparams.api_keys.push_back(key); + } + } + key_file.close(); } else if (arg == "--timeout" || arg == "-to") { @@ -2881,8 +2903,10 @@ int main(int argc, char **argv) log_data["hostname"] = sparams.hostname; log_data["port"] = std::to_string(sparams.port); - if (!sparams.api_key.empty()) { - log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4); + if (sparams.api_keys.size() == 1) { + log_data["api_key"] = "api_key: ****" + sparams.api_keys[0].substr(sparams.api_keys[0].length() - 4); + } else if (sparams.api_keys.size() > 1) { + log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded"; } LOG_INFO("HTTP server listening", log_data); @@ -2912,7 +2936,7 @@ int main(int argc, char **argv) // Middleware for API key validation auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { // If API key is not set, skip validation - if (sparams.api_key.empty()) { + if (sparams.api_keys.empty()) { return true; } @@ -2921,7 +2945,7 @@ int main(int argc, char **argv) std::string prefix = "Bearer "; if (auth_header.substr(0, prefix.size()) == prefix) { std::string received_api_key = auth_header.substr(prefix.size()); - if (received_api_key == sparams.api_key) { + if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) { return true; // API key is valid } } |