From 7ee76e45afae7f9a7a53e93393accfb5b36684e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tobias=20L=C3=BCtke?=
Date: Tue, 4 Jul 2023 10:05:27 -0400
Subject: Simple webchat for server (#1998)

* expose simple web interface on root domain

* embed index and add --path for choosing static dir

* allow server to multithread

Because web browsers send a lot of garbage requests, we want the server
to multithread when serving 404s for favicons, etc. To avoid blowing up
llama, we just take a mutex when it's invoked.

* let's try this with the xxd tool instead and see if msvc is happier with that

* enable server in Makefiles

* add /completion.js file to make it easy to use the server from js

* slightly nicer css

* rework state management into session, expose historyTemplate to settings

---------

Co-authored-by: Georgi Gerganov
---
 examples/server/server.cpp | 69 +++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 62 insertions(+), 7 deletions(-)

(limited to 'examples/server/server.cpp')
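Note (editorial, not part of the patch): the "allow server to multithread" bullet boils down to the pattern sketched below. httplib is left free to run request handlers on multiple threads, so 404s and static files are served concurrently, while every handler that touches the single llama context first takes a mutex, keeping inference serialized. The sketch is illustrative only; the struct name is a stand-in for llama_server_context, and the real code appears in the diff that follows.

    // Illustrative sketch of the locking pattern (names are placeholders).
    #include <mutex>

    struct shared_llama_state {
        std::mutex mutex;

        // The returned guard holds the mutex for as long as the caller keeps it,
        // i.e. for the lifetime of the request handler that called lock().
        std::unique_lock<std::mutex> lock() {
            return std::unique_lock<std::mutex>(mutex);
        }
    };

    // inside a handler:  auto guard = state.lock();  // released when the handler returns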
"enabled" : "disabled"); fprintf(stderr, "\n"); @@ -565,6 +576,12 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams, break; } sparams.hostname = argv[i]; + } else if (arg == "--path") { + if (++i >= argc) { + invalid_param = true; + break; + } + sparams.public_path = argv[i]; } else if (arg == "--timeout" || arg == "-to") { if (++i >= argc) { invalid_param = true; @@ -839,17 +856,24 @@ static void parse_options_completion(const json & body, llama_server_context & l LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } + static void log_server_request(const Request & req, const Response & res) { LOG_INFO("request", { { "remote_addr", req.remote_addr }, { "remote_port", req.remote_port }, { "status", res.status }, + { "method", req.method }, { "path", req.path }, + { "params", req.params }, + }); + + LOG_VERBOSE("request", { { "request", req.body }, { "response", res.body }, }); } + int main(int argc, char ** argv) { // own arguments required by this example gpt_params params; @@ -884,16 +908,34 @@ int main(int argc, char ** argv) { Server svr; svr.set_default_headers({ + { "Server", "llama.cpp" }, { "Access-Control-Allow-Origin", "*" }, { "Access-Control-Allow-Headers", "content-type" } }); + // this is only called if no index.js is found in the public --path + svr.Get("/index.js", [](const Request &, Response & res) { + res.set_content(reinterpret_cast(&index_js), index_js_len, "text/javascript"); + return false; + }); + + // this is only called if no index.html is found in the public --path svr.Get("/", [](const Request &, Response & res) { - res.set_content("

llama.cpp server works

", "text/html"); + res.set_content(reinterpret_cast(&index_html), index_html_len, "text/html"); + return false; + }); + + // this is only called if no index.html is found in the public --path + svr.Get("/completion.js", [](const Request &, Response & res) { + res.set_content(reinterpret_cast(&completion_js), completion_js_len, "application/javascript"); + return false; }); svr.Post("/completion", [&llama](const Request & req, Response & res) { + auto lock = llama.lock(); + llama.rewind(); + llama_reset_timings(llama.ctx); parse_options_completion(json::parse(req.body), llama); @@ -1002,6 +1044,8 @@ int main(int argc, char ** argv) { }); svr.Post("/tokenize", [&llama](const Request & req, Response & res) { + auto lock = llama.lock(); + const json body = json::parse(req.body); const std::string content = body.value("content", ""); const std::vector tokens = llama_tokenize(llama.ctx, content, false); @@ -1010,6 +1054,8 @@ int main(int argc, char ** argv) { }); svr.Post("/embedding", [&llama](const Request & req, Response & res) { + auto lock = llama.lock(); + const json body = json::parse(req.body); llama.rewind(); @@ -1040,18 +1086,27 @@ int main(int argc, char ** argv) { res.status = 500; }); + svr.set_error_handler([](const Request &, Response & res) { + res.set_content("File Not Found", "text/plain"); + res.status = 404; + }); + + // set timeouts and change hostname and port svr.set_read_timeout(sparams.read_timeout); svr.set_write_timeout(sparams.write_timeout); if (!svr.bind_to_port(sparams.hostname, sparams.port)) { - LOG_ERROR("couldn't bind to server socket", { - { "hostname", sparams.hostname }, - { "port", sparams.port }, - }); + fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); return 1; } + // Set the base directory for serving static files + svr.set_base_dir(sparams.public_path); + + // to make it ctrl+clickable: + fprintf(stdout, "\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); + LOG_INFO("HTTP server listening", { { "hostname", sparams.hostname }, { "port", sparams.port }, -- cgit v1.2.3