diff options
Diffstat (limited to 'examples/rpc/rpc-server.cpp')
-rw-r--r-- | examples/rpc/rpc-server.cpp | 68 |
1 files changed, 57 insertions, 11 deletions
diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp index 496af849..021185b8 100644 --- a/examples/rpc/rpc-server.cpp +++ b/examples/rpc/rpc-server.cpp @@ -10,6 +10,52 @@ #include <string> #include <stdio.h> +struct rpc_server_params { + std::string host = "0.0.0.0"; + int port = 50052; + size_t backend_mem = 0; +}; + +static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) { + fprintf(stderr, "Usage: %s [options]\n\n", argv[0]); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str()); + fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port); + fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n"); + fprintf(stderr, "\n"); +} + +static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) { + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + if (arg == "-H" || arg == "--host") { + if (++i >= argc) { + return false; + } + params.host = argv[i]; + } else if (arg == "-p" || arg == "--port") { + if (++i >= argc) { + return false; + } + params.port = std::stoi(argv[i]); + if (params.port <= 0 || params.port > 65535) { + return false; + } + } else if (arg == "-m" || arg == "--mem") { + if (++i >= argc) { + return false; + } + params.backend_mem = std::stoul(argv[i]) * 1024 * 1024; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv, params); + exit(0); + } + } + return true; +} + static ggml_backend_t create_backend() { ggml_backend_t backend = NULL; #ifdef GGML_USE_CUDA @@ -45,14 +91,9 @@ static void get_backend_memory(size_t * free_mem, size_t * total_mem) { } int main(int argc, char * argv[]) { - if (argc < 3) { - fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]); - return 1; - } - const char * host = argv[1]; - int port = std::stoi(argv[2]); - if (port <= 0 || port > 65535) { - fprintf(stderr, "Invalid port number: %d\n", port); + rpc_server_params params; + if (!rpc_server_params_parse(argc, argv, params)) { + fprintf(stderr, "Invalid parameters\n"); return 1; } ggml_backend_t backend = create_backend(); @@ -60,10 +101,15 @@ int main(int argc, char * argv[]) { fprintf(stderr, "Failed to create backend\n"); return 1; } - printf("Starting RPC server on %s:%d\n", host, port); + std::string endpoint = params.host + ":" + std::to_string(params.port); size_t free_mem, total_mem; - get_backend_memory(&free_mem, &total_mem); - std::string endpoint = std::string(host) + ":" + std::to_string(port); + if (params.backend_mem > 0) { + free_mem = params.backend_mem; + total_mem = params.backend_mem; + } else { + get_backend_memory(&free_mem, &total_mem); + } + printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024)); start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem); ggml_backend_free(backend); return 0; |