author | firecoperana <xuqiaowei1124@gmail.com> | 2025-06-08 14:27:00 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-08 17:27:00 +0300 |
commit | 58f08e43859a942dcc4d585f04b729eb50603264 (patch) | |
tree | e1f1370970eb4f871c69468a83d96c2c216a91e5 /examples/rpc/rpc-server.cpp | |
parent | 1eabdb420b3b7b8464bb2b44d9e797b141a580f6 (diff) |
Fix non rpc build error (#506)
* Add RPC backend in device list to override tensors.
* rpc : prevent crashes on invalid input (#9040)
Add more checks which prevent RPC server from crashing if invalid input
is received from client
# Conflicts:
# ggml/src/ggml-rpc.cpp
* rpc : print error message when failed to connect endpoint (#9042)
* Fix RPC error
* Add vulkan, sycl to rpc backend
* add thread in rpc cpu backend
* add cache folder and other improvements in rpc
* add header file
* support for models with non-512 aligned tensors
* rpc : do not wait for response when sending RPC_CMD_SET_TENSOR (#12943)
RPC_CMD_SET_TENSOR always returns an empty response and we send this 4
times per token. We can improve TG speed if we don't wait for this empty
response.
The performance impact of this change depends on the network latency.
# Conflicts:
# ggml/src/ggml-rpc.cpp
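A minimal sketch of the fire-and-forget idea described above, assuming a hypothetical in-memory `fake_conn` in place of a real socket and hypothetical command IDs (`SKETCH_CMD_*`); none of these names come from `ggml-rpc.cpp`. It only illustrates that skipping the read for a command with an empty reply removes one blocking round trip per send.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical command IDs for illustration; the real enum lives in ggml-rpc.cpp.
enum sketch_cmd : uint8_t { SKETCH_CMD_SET_TENSOR = 1, SKETCH_CMD_GET_TENSOR = 2 };

// In-memory stand-in for a socket that counts blocking reads.
struct fake_conn {
    std::vector<uint8_t> sent;      // bytes written to the "wire"
    int blocking_reads = 0;         // each one would cost a network round trip

    void write_bytes(const void * data, size_t n) {
        const uint8_t * p = static_cast<const uint8_t *>(data);
        sent.insert(sent.end(), p, p + n);
    }
    void read_response() { blocking_reads++; }
};

// Assumption of this sketch: only SET_TENSOR has an empty reply worth skipping.
inline bool cmd_expects_response(sketch_cmd cmd) {
    return cmd != SKETCH_CMD_SET_TENSOR;
}

// Write command byte, payload size, payload; read a reply only when one is expected.
inline void send_cmd(fake_conn & conn, sketch_cmd cmd, const std::vector<uint8_t> & payload) {
    const uint64_t size = payload.size();
    conn.write_bytes(&cmd, sizeof(cmd));
    conn.write_bytes(&size, sizeof(size));
    conn.write_bytes(payload.data(), payload.size());
    if (cmd_expects_response(cmd)) {
        conn.read_response();
    }
}

// Usage: four SET_TENSOR sends (one token's worth, per the message above)
// complete without a single blocking read.
inline void demo_set_tensor_no_wait() {
    fake_conn conn;
    for (int i = 0; i < 4; i++) {
        send_cmd(conn, SKETCH_CMD_SET_TENSOR, {0x01, 0x02});
    }
    assert(conn.blocking_reads == 0);
}
```

The saving scales with latency: on a loopback connection the skipped read is nearly free, while on a slow link it is a full round trip per call.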
* fix(rpc): Improve input validation and error handling (#13069)
* fix(rpc): Improve input validation and error handling
The `rpc-server` was vulnerable to Denial of Service attacks via
several RPC commands (`SET_TENSOR`, `GRAPH_COMPUTE`, etc.). Malformed
messages could trigger failed assertions (e.g., invalid `ggml_type`)
or out-of-bounds reads/writes leading to `GGML_ABORT` calls,
crashing the server process.
This PR introduces robust input validation and replaces `abort()`
calls with graceful error handling:
- **Type Validation:** `deserialize_tensor` now checks if the
`tensor->type` is within the valid `GGML_TYPE_COUNT` range
*before* calling `ggml_new_tensor_4d`. Returns `nullptr` on
invalid type.
- **Bounds Checks:** Replaced `GGML_ABORT` in `set_tensor`,
`set_tensor_hash`, and `get_tensor` handlers with error
logging and returning `false` when data/offset parameters
are out of buffer bounds.
- **Size Checks:** Added safe arithmetic checks (for overflow) in
`graph_compute` when calculating required message sizes based
on client-provided `n_nodes` and `n_tensors`. Returns early
if the reported sizes conflict with the actual message size or
would lead to overflow.
- **Error Propagation:**
- `create_node` now checks for `nullptr` return values from
`deserialize_tensor` and its recursive calls, propagating
`nullptr` upwards on failure. Uses `find` instead of `at`
for safer map access.
- `copy_tensor` now checks for `nullptr` from `deserialize_tensor`
and sets the response status to failure if deserialization
or bounds checks fail.
- `graph_compute` now checks for `nullptr` return from
`create_node` and returns failure status correctly. The final
return value now reflects the actual computation status.
These changes improve the RPC server's resilience
against malformed client requests, preventing crashes and ensuring
errors are handled more gracefully.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
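The bounds checks listed above can be sketched as a single predicate. This is an illustration under assumptions, not the code from the PR: `range_in_buffer` and its parameter names are hypothetical, and addresses are modeled as plain 64-bit integers. The comparisons are ordered so that no addition on client-provided values can wrap around.

```cpp
#include <cassert>
#include <cstdint>

// True when [data + offset, data + offset + size) lies inside
// [base, base + buf_size). Only subtractions that are already known to be
// safe are performed, so hostile values cannot overflow the arithmetic.
inline bool range_in_buffer(uint64_t base, uint64_t buf_size,
                            uint64_t data, uint64_t offset, uint64_t size) {
    if (data < base) {
        return false;                       // tensor data starts before the buffer
    }
    const uint64_t rel = data - base;       // position of the tensor data in the buffer
    if (rel > buf_size) {
        return false;                       // tensor data starts past the buffer end
    }
    const uint64_t room = buf_size - rel;   // bytes available from the tensor data onward
    if (offset > room) {
        return false;
    }
    return size <= room - offset;
}

// Usage: a handler would log and return false instead of aborting.
inline void demo_range_in_buffer() {
    assert(range_in_buffer(1000, 100, 1000, 0, 100));          // exact fit
    assert(!range_in_buffer(1000, 100, 1000, 0, 101));         // one byte too many
    assert(!range_in_buffer(1000, 100, 1000, UINT64_MAX, 1));  // offset cannot wrap
}
```

Returning `false` from the handler lets the server drop the offending connection while the process keeps running.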
* refactor(rpc): address pr comments
removed comments and unnecessary returns
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
* refactor(rpc): ambiguous nullptr from create_node
rpc_server::create_node could previously return nullptr if the input ID
was 0 (valid) or if an internal error (deserialization, recursion
failure) occurred (invalid). This ambiguity made error handling
difficult for the caller (`graph_compute`).
This commit clarifies the meaning of nullptr:
- `graph_compute` now checks if the input 'id' was non-zero when
`create_node` returns nullptr, correctly identifying failures
versus intentional null links.
- `create_node` avoids recursive calls for zero IDs and propagates
nullptr unambiguously on failure during recursion.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
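The `nullptr` convention described above can be sketched as follows. All types and names here (`wire_node`, `node`, `create_node_sketch`, `build_graph_sketch`) are hypothetical stand-ins, and each node has a single source instead of the several that real tensors carry. The point is only the caller-side rule: a `nullptr` result is an error exactly when the requested ID was non-zero.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Serialized form: src == 0 means "no source"; valid == false simulates a
// tensor that fails deserialization.
struct wire_node { uint64_t id; uint64_t src; bool valid; };
struct node      { uint64_t id; node * src; };

using wire_map  = std::unordered_map<uint64_t, wire_node>;
using built_map = std::unordered_map<uint64_t, std::unique_ptr<node>>;

inline node * create_node_sketch(uint64_t id, const wire_map & wire, built_map & built) {
    auto done = built.find(id);
    if (done != built.end()) {
        return done->second.get();          // already built
    }
    auto it = wire.find(id);                // find() instead of at(): no exception on a bad ID
    if (it == wire.end() || !it->second.valid) {
        return nullptr;
    }
    std::unique_ptr<node> owned(new node{id, nullptr});
    node * result = owned.get();
    built[id] = std::move(owned);           // register before recursing
    const uint64_t src_id = it->second.src;
    if (src_id != 0) {                      // zero IDs are never recursed into
        result->src = create_node_sketch(src_id, wire, built);
        if (result->src == nullptr) {
            return nullptr;                 // propagate the failure upwards
        }
    }
    return result;
}

inline bool build_graph_sketch(const std::vector<uint64_t> & ids, const wire_map & wire, built_map & built) {
    for (uint64_t id : ids) {
        node * n = create_node_sketch(id, wire, built);
        if (n == nullptr && id != 0) {
            return false;                   // real failure, not an intentional null link
        }
    }
    return true;
}

// Usage: ID 0 is accepted as a null link, a dangling source ID is rejected.
inline void demo_create_node() {
    wire_map wire;
    wire[1] = wire_node{1, 2, true};
    wire[2] = wire_node{2, 0, true};
    wire[3] = wire_node{3, 4, true};        // source 4 is missing
    built_map ok, bad;
    assert(build_graph_sketch({1, 0}, wire, ok));
    assert(!build_graph_sketch({3}, wire, bad));
}
```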
* refactor(rpc): initial zero check in create_node
The caller (`graph_compute`) already checks `id != 0` when handling
a `nullptr` return from `create_node`, correctly distinguishing
intentional null links from actual errors. This makes the initial
`if (id == 0)` check redundant.
Also removes the log message when a tensor ID is not found in the
provided map which was added in this branch.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
* fix(rpc): Handle get_alloc_size failure in server
Check the return value of `server.get_alloc_size` in the RPC server
loop. If the call fails, return early to close the connection.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
* refactor(rpc): input size validation in graph_compute
Removes detailed, step-by-step size calculations and overflow
checks in favor of simpler direct comparisons, assuming 64-bit
overflow is unlikely.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
* refactor(rpc): remove extra status code setting
Removes the explicit setting of `response.result = GGML_STATUS_FAILED`
when `create_node` returns `nullptr` within `graph_compute`.
Primary signal is the `false` return value in case of failure.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
* refactor(rpc): remove redundant check for tensor->type
Breaks CI on ubuntu-cpu-make. Tensor type is uint32_t, thus
the check is not needed.
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
---------
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
# Conflicts:
# ggml/src/ggml-rpc.cpp
* rpc : fix cache directory initialization (#13188)
Signed-off-by: xiaofei <hbuxiaofei@gmail.com>
# Conflicts:
# examples/rpc/rpc-server.cpp
* rpc : avoid uninitialized memory in serialize_tensor (#13210)
Zero out the name and padding buffers.
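A minimal sketch of why zeroing matters, assuming a hypothetical fixed-size wire struct (`wire_tensor_sketch`; the field sizes are illustrative, not the real layout). The struct is sent over the network as raw bytes, so whatever happens to sit after the name's terminator or in the padding would otherwise leak to the peer. This sketch clears the whole struct, which is a superset of zeroing only the name and padding buffers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical wire layout for illustration only.
struct wire_tensor_sketch {
    uint64_t id;
    uint32_t type;
    char     name[64];
    char     padding[4];
};

inline void serialize_tensor_sketch(wire_tensor_sketch & out, uint64_t id, uint32_t type, const char * name) {
    std::memset(&out, 0, sizeof(out));                       // no stale bytes survive
    out.id   = id;
    out.type = type;
    std::snprintf(out.name, sizeof(out.name), "%s", name);  // truncates and always terminates
}

// Usage: even when the destination starts out full of garbage, every byte
// after the name's terminator is zero afterwards.
inline void demo_serialize_tensor() {
    wire_tensor_sketch t;
    std::memset(&t, 0xAA, sizeof(t));
    serialize_tensor_sketch(t, 7, 1, "blk.0.attn_q");
    for (size_t i = std::strlen(t.name); i < sizeof(t.name); i++) {
        assert(t.name[i] == 0);
    }
}
```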
* fix merge error
* Add hello command in RPC
* bug fix
* add rpc header
* fix bug for missing rpc names
* add tcp no delay for rpc
* add back webui
* fix rpc function not found error
---------
Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
Signed-off-by: xiaofei <hbuxiaofei@gmail.com>
Co-authored-by: firecoperana <firecoperana>
Co-authored-by: Radoslav Gerganov <rgerganov@gmail.com>
Co-authored-by: matt23456 <matt23456>
Co-authored-by: Ville Vesilehto <ville@vesilehto.fi>
Co-authored-by: xiaofei <hbuxiaofei@gmail.com>
Co-authored-by: Justin Santa Barbara <justinsb@google.com>
Diffstat (limited to 'examples/rpc/rpc-server.cpp')
-rw-r--r-- | examples/rpc/rpc-server.cpp | 214 |
1 files changed, 199 insertions, 15 deletions
diff --git a/examples/rpc/rpc-server.cpp b/examples/rpc/rpc-server.cpp
index 6342e648..943c1b1c 100644
--- a/examples/rpc/rpc-server.cpp
+++ b/examples/rpc/rpc-server.cpp
@@ -5,33 +5,166 @@
 #ifdef GGML_USE_METAL
 #include "ggml-metal.h"
 #endif
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
 #include "ggml-rpc.h"
 #ifdef _WIN32
+# define DIRECTORY_SEPARATOR '\\'
+# define NOMINMAX
+# include <locale>
 # include <windows.h>
+# include <fcntl.h>
+# include <io.h>
 #else
+# define DIRECTORY_SEPARATOR '/'
 # include <unistd.h>
+# include <sys/stat.h>
 #endif
 #include <string>
 #include <stdio.h>
+#include <algorithm>
+#include <thread>
+#include <fstream>
+#include <filesystem>
+#include <codecvt>
+
+namespace fs = std::filesystem;
+
+// NOTE: this is copied from common.cpp to avoid linking with libcommon
+// returns true if successful, false otherwise
+static bool fs_create_directory_with_parents(const std::string& path) {
+#ifdef _WIN32
+    std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
+    std::wstring wpath = converter.from_bytes(path);
+
+    // if the path already exists, check whether it's a directory
+    const DWORD attributes = GetFileAttributesW(wpath.c_str());
+    if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+        return true;
+    }
+
+    size_t pos_slash = 0;
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+        const std::wstring subpath = wpath.substr(0, pos_slash);
+        const wchar_t* test = subpath.c_str();
+
+        const bool success = CreateDirectoryW(test, NULL);
+        if (!success) {
+            const DWORD error = GetLastError();
+
+            // if the path already exists, ensure that it's a directory
+            if (error == ERROR_ALREADY_EXISTS) {
+                const DWORD attributes = GetFileAttributesW(subpath.c_str());
+                if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+                    return false;
+                }
+            }
+            else {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#else
+    // if the path already exists, check whether it's a directory
+    struct stat info;
+    if (stat(path.c_str(), &info) == 0) {
+        return S_ISDIR(info.st_mode);
+    }
+
+    size_t pos_slash = 1; // skip leading slashes for directory creation
+
+    // process path from front to back, procedurally creating directories
+    while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+        const std::string subpath = path.substr(0, pos_slash);
+        struct stat info;
+
+        // if the path already exists, ensure that it's a directory
+        if (stat(subpath.c_str(), &info) == 0) {
+            if (!S_ISDIR(info.st_mode)) {
+                return false;
+            }
+        }
+        else {
+            // create parent directories
+            const int ret = mkdir(subpath.c_str(), 0755);
+            if (ret != 0) {
+                return false;
+            }
+        }
+
+        pos_slash += 1;
+    }
+
+    return true;
+#endif // _WIN32
+}
+
+// NOTE: this is copied from common.cpp to avoid linking with libcommon
+static std::string fs_get_cache_directory() {
+    std::string cache_directory = "";
+    auto ensure_trailing_slash = [](std::string p) {
+        // Make sure to add trailing slash
+        if (p.back() != DIRECTORY_SEPARATOR) {
+            p += DIRECTORY_SEPARATOR;
+        }
+        return p;
+    };
+    if (getenv("LLAMA_CACHE")) {
+        cache_directory = std::getenv("LLAMA_CACHE");
+    }
+    else {
+#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX)
+        if (std::getenv("XDG_CACHE_HOME")) {
+            cache_directory = std::getenv("XDG_CACHE_HOME");
+        }
+        else {
+            cache_directory = std::getenv("HOME") + std::string("/.cache/");
+        }
+#elif defined(__APPLE__)
+        cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
+#elif defined(_WIN32)
+        cache_directory = std::getenv("LOCALAPPDATA");
+#else
+# error Unknown architecture
+#endif
+        cache_directory = ensure_trailing_slash(cache_directory);
+        cache_directory += "llama.cpp";
+    }
+    return ensure_trailing_slash(cache_directory);
+}
 struct rpc_server_params {
     std::string host = "127.0.0.1";
     int port = 50052;
     size_t backend_mem = 0;
+    bool use_cache = false;
+    int n_threads = std::max(1U, std::thread::hardware_concurrency() / 2);
 };
-static void print_usage(int /*argc*/, char ** argv, rpc_server_params params) {
+static void print_usage(int /*argc*/, char** argv, rpc_server_params params) {
     fprintf(stderr, "Usage: %s [options]\n\n", argv[0]);
     fprintf(stderr, "options:\n");
-    fprintf(stderr, " -h, --help show this help message and exit\n");
-    fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
-    fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
-    fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
+    fprintf(stderr, " -h, --help show this help message and exit\n");
+    fprintf(stderr, " -t, --threads number of threads for the CPU backend (default: %d)\n", params.n_threads);
+    fprintf(stderr, " -H HOST, --host HOST host to bind to (default: %s)\n", params.host.c_str());
+    fprintf(stderr, " -p PORT, --port PORT port to bind to (default: %d)\n", params.port);
+    fprintf(stderr, " -m MEM, --mem MEM backend memory size (in MB)\n");
+    fprintf(stderr, " -c, --cache enable local file cache\n");
     fprintf(stderr, "\n");
 }
-static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params & params) {
+static bool rpc_server_params_parse(int argc, char** argv, rpc_server_params& params) {
     std::string arg;
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -40,7 +173,18 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
                 return false;
             }
             params.host = argv[i];
-        } else if (arg == "-p" || arg == "--port") {
+        }
+        else if (arg == "-t" || arg == "--threads") {
+            if (++i >= argc) {
+                return false;
+            }
+            params.n_threads = std::stoi(argv[i]);
+            if (params.n_threads <= 0) {
+                fprintf(stderr, "error: invalid number of threads: %d\n", params.n_threads);
+                return false;
+            }
+        }
+        else if (arg == "-p" || arg == "--port") {
             if (++i >= argc) {
                 return false;
             }
@@ -48,15 +192,21 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
             if (params.port <= 0 || params.port > 65535) {
                 return false;
             }
-        } else if (arg == "-m" || arg == "--mem") {
+        }
+        else if (arg == "-c" || arg == "--cache") {
+            params.use_cache = true;
+        }
+        else if (arg == "-m" || arg == "--mem") {
             if (++i >= argc) {
                 return false;
             }
             params.backend_mem = std::stoul(argv[i]) * 1024 * 1024;
-        } else if (arg == "-h" || arg == "--help") {
+        }
+        else if (arg == "-h" || arg == "--help") {
             print_usage(argc, argv, params);
             exit(0);
-        } else {
+        }
+        else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             print_usage(argc, argv, params);
             exit(0);
@@ -65,7 +215,7 @@ static bool rpc_server_params_parse(int argc, char ** argv, rpc_server_params &
     return true;
 }
-static ggml_backend_t create_backend() {
+static ggml_backend_t create_backend(const rpc_server_params& params) {
     ggml_backend_t backend = NULL;
 #ifdef GGML_USE_CUDA
     fprintf(stderr, "%s: using CUDA backend\n", __func__);
@@ -79,12 +229,25 @@ static ggml_backend_t create_backend() {
     if (!backend) {
         fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
     }
+#elif GGML_USE_VULKAN
+    fprintf(stderr, "%s: using Vulkan backend\n", __func__);
+    backend = ggml_backend_vk_init(0); // init device 0
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_vulkan_init() failed\n", __func__);
+    }
+#elif GGML_USE_SYCL
+    fprintf(stderr, "%s: using SYCL backend\n", __func__);
+    backend = ggml_backend_sycl_init(0); // init device 0
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_sycl_init() failed\n", __func__);
+    }
 #endif
     // if there aren't GPU Backends fallback to CPU backend
     if (!backend) {
         fprintf(stderr, "%s: using CPU backend\n", __func__);
         backend = ggml_backend_cpu_init();
+        ggml_backend_cpu_set_n_threads(backend, params.n_threads);
     }
     return backend;
 }
@@ -92,6 +255,10 @@ static ggml_backend_t create_backend() {
 static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
 #ifdef GGML_USE_CUDA
     ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
+#elif GGML_USE_VULKAN
+    ggml_backend_vk_get_device_memory(0, free_mem, total_mem);
+#elif GGML_USE_SYCL
+    ggml_backend_sycl_get_device_memory(0, free_mem, total_mem);
 #else
 #ifdef _WIN32
     MEMORYSTATUSEX status;
@@ -125,7 +292,7 @@ int main(int argc, char * argv[]) {
         fprintf(stderr, "\n");
     }
-    ggml_backend_t backend = create_backend();
+    ggml_backend_t backend = create_backend(params);
     if (!backend) {
         fprintf(stderr, "Failed to create backend\n");
         return 1;
@@ -135,11 +302,28 @@ int main(int argc, char * argv[]) {
     if (params.backend_mem > 0) {
         free_mem = params.backend_mem;
         total_mem = params.backend_mem;
-    } else {
+    }
+    else {
         get_backend_memory(&free_mem, &total_mem);
     }
-    printf("Starting RPC server on %s, backend memory: %zu MB\n", endpoint.c_str(), free_mem / (1024 * 1024));
-    start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
+    const char * cache_dir = nullptr;
+    std::string cache_dir_str;
+    if (params.use_cache) {
+        cache_dir_str = fs_get_cache_directory() + "rpc/";
+        if (!fs_create_directory_with_parents(cache_dir_str)) {
+            fprintf(stderr, "Failed to create cache directory: %s\n", cache_dir_str.c_str());
+            return 1;
+        }
+        cache_dir = cache_dir_str.c_str();
+    }
+    printf("Starting RPC server v%d.%d.%d\n",
+        RPC_PROTO_MAJOR_VERSION,
+        RPC_PROTO_MINOR_VERSION,
+        RPC_PROTO_PATCH_VERSION);
+    printf(" endpoint : %s\n", endpoint.c_str());
+    printf(" local cache : %s\n", cache_dir ? cache_dir : "n/a");
+    printf(" backend memory : %zu MB\n", free_mem / (1024 * 1024));
+    ggml_backend_rpc_start_server(backend, endpoint.c_str(), cache_dir, free_mem, total_mem);
     ggml_backend_free(backend);
     return 0;
 }
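The `--threads` handling in the diff calls `std::stoi` directly, and `std::stoi` throws `std::invalid_argument` or `std::out_of_range` when the text is not a usable number. A guarded variant might look like the sketch below; `parse_positive_int` and `default_n_threads` are hypothetical helper names that are not part of this commit, and the default mirrors the `n_threads` initializer in `rpc_server_params`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <thread>

// Same default as the patch: half the hardware threads, never less than one.
// hardware_concurrency() may return 0 when the value is unknown.
inline int default_n_threads() {
    return (int) std::max(1U, std::thread::hardware_concurrency() / 2);
}

// Parse a strictly positive integer without letting exceptions escape.
// Rejects trailing characters ("4x"), non-numbers, and out-of-range values.
inline bool parse_positive_int(const std::string & text, int & out) {
    try {
        std::size_t used = 0;
        const int value = std::stoi(text, &used);
        if (used != text.size() || value <= 0) {
            return false;
        }
        out = value;
        return true;
    } catch (const std::invalid_argument &) {
        return false;
    } catch (const std::out_of_range &) {
        return false;
    }
}

// Usage: fall back to the default when the argument is unusable.
inline void demo_parse_threads() {
    int n_threads = default_n_threads();
    assert(n_threads >= 1);
    assert(parse_positive_int("8", n_threads) && n_threads == 8);
    assert(!parse_positive_int("abc", n_threads) && n_threads == 8);
}
```

With a helper like this, a bad `-t` value would produce the existing "invalid number of threads" style error path instead of terminating on an uncaught exception.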