diff options
author | Radoslav Gerganov <rgerganov@gmail.com> | 2024-05-14 14:27:19 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-14 14:27:19 +0300 |
commit | 5e31828d3e35c76ecfee665bc23771a4bec1d130 (patch) | |
tree | 7f5f2edc7c3fc3e7655904316897e32202edd5d6 /llama.cpp | |
parent | 541600201e6480f54ae09e58d16b154d4b4b331d (diff) |
ggml : add RPC backend (#6829)
* ggml : add RPC backend
The RPC backend proxies all operations to a remote server which runs a
regular backend (CPU, CUDA, Metal, etc).
* set TCP_NODELAY
* add CI workflows
* Address review comments
* fix warning
* implement llama_max_devices() for RPC
* Address review comments
* Address review comments
* wrap sockfd into a struct
* implement get_alignment and get_max_size
* add get_device_memory
* fix warning
* win32 support
* add README
* readme : trim trailing whitespace
* Address review comments
* win32 fix
* Address review comments
* fix compile warnings on macos
Diffstat (limited to 'llama.cpp')
-rw-r--r-- | llama.cpp | 238 |
1 files changed, 140 insertions, 98 deletions
@@ -7,6 +7,10 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#ifdef GGML_USE_RPC +# include "ggml-rpc.h" +#endif + #ifdef GGML_USE_CUDA # include "ggml-cuda.h" #elif defined(GGML_USE_CLBLAST) @@ -1685,91 +1689,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer GGML_UNUSED(host_buffer); } -static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) { - ggml_backend_buffer_type_t buft = nullptr; - -#ifdef GGML_USE_METAL - buft = ggml_backend_metal_buffer_type(); -#elif defined(GGML_USE_CUDA) - buft = ggml_backend_cuda_buffer_type(gpu); -#elif defined(GGML_USE_VULKAN) - buft = ggml_backend_vk_buffer_type(gpu); -#elif defined(GGML_USE_SYCL) - buft = ggml_backend_sycl_buffer_type(gpu); -#elif defined(GGML_USE_CLBLAST) - buft = ggml_backend_opencl_buffer_type(); -#elif defined(GGML_USE_KOMPUTE) - buft = ggml_backend_kompute_buffer_type(gpu); - if (buft == nullptr) { - LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); - } -#endif - - if (buft == nullptr) { - buft = llama_default_buffer_type_cpu(true); - } - return buft; - - GGML_UNUSED(gpu); -} - -static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_gpu, const float * tensor_split) { - ggml_backend_buffer_type_t buft = nullptr; - -#ifdef GGML_USE_CUDA - if (ggml_backend_cuda_get_device_count() > 1) { - buft = ggml_backend_cuda_split_buffer_type(tensor_split); - } -#endif - -#ifdef GGML_USE_SYCL - if (ggml_backend_sycl_get_device_count() > 1) { - buft = ggml_backend_sycl_split_buffer_type(tensor_split); - } -#endif - - if (buft == nullptr) { - buft = llama_default_buffer_type_offload(fallback_gpu); - } - return buft; - - GGML_UNUSED(tensor_split); -} - -static size_t llama_get_device_count() { -#if defined(GGML_USE_CUDA) - return ggml_backend_cuda_get_device_count(); -#elif defined(GGML_USE_SYCL) - return ggml_backend_sycl_get_device_count(); -#elif defined(GGML_USE_VULKAN) - return ggml_backend_vk_get_device_count(); -#else - return 1; -#endif -} - -static size_t llama_get_device_memory(int device) { -#if defined(GGML_USE_CUDA) - size_t total; - size_t free; - ggml_backend_cuda_get_device_memory(device, &free, &total); - return free; -#elif defined(GGML_USE_SYCL) - size_t total; - size_t free; - ggml_backend_sycl_get_device_memory(device, &free, &total); - return free; -#elif defined(GGML_USE_VULKAN) - size_t total; - size_t free; - ggml_backend_vk_get_device_memory(device, &free, &total); - return free; -#else - return 1; - GGML_UNUSED(device); -#endif -} - // // globals // @@ -2210,6 +2129,8 @@ struct llama_model { int main_gpu; int n_gpu_layers; + std::vector<std::string> rpc_servers; + // gguf metadata std::unordered_map<std::string, std::string> gguf_kv; @@ -2353,6 +2274,104 @@ struct llama_context { #endif }; +static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) { + ggml_backend_buffer_type_t buft = nullptr; + +#ifdef GGML_USE_RPC + std::string endpoint = model.rpc_servers[gpu]; + buft = ggml_backend_rpc_buffer_type(endpoint.c_str()); +#elif defined(GGML_USE_METAL) + buft = ggml_backend_metal_buffer_type(); +#elif defined(GGML_USE_CUDA) + buft = ggml_backend_cuda_buffer_type(gpu); +#elif defined(GGML_USE_VULKAN) + buft = ggml_backend_vk_buffer_type(gpu); +#elif defined(GGML_USE_SYCL) + buft = ggml_backend_sycl_buffer_type(gpu); +#elif defined(GGML_USE_CLBLAST) + buft = ggml_backend_opencl_buffer_type(); +#elif defined(GGML_USE_KOMPUTE) + buft = ggml_backend_kompute_buffer_type(gpu); + if (buft == nullptr) { + LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu); + } +#endif + + if (buft == nullptr) { + buft = llama_default_buffer_type_cpu(true); + } + return buft; + GGML_UNUSED(model); + GGML_UNUSED(gpu); +} + +static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) { + ggml_backend_buffer_type_t buft = nullptr; + +#ifdef GGML_USE_CUDA + if (ggml_backend_cuda_get_device_count() > 1) { + buft = ggml_backend_cuda_split_buffer_type(tensor_split); + } +#endif + +#ifdef GGML_USE_SYCL + if (ggml_backend_sycl_get_device_count() > 1) { + buft = ggml_backend_sycl_split_buffer_type(tensor_split); + } +#endif + + if (buft == nullptr) { + buft = llama_default_buffer_type_offload(model, fallback_gpu); + } + return buft; + + GGML_UNUSED(tensor_split); +} + +static size_t llama_get_device_count(const llama_model & model) { +#if defined(GGML_USE_RPC) + return model.rpc_servers.size(); +#elif defined(GGML_USE_CUDA) + return ggml_backend_cuda_get_device_count(); +#elif defined(GGML_USE_SYCL) + return ggml_backend_sycl_get_device_count(); +#elif defined(GGML_USE_VULKAN) + return ggml_backend_vk_get_device_count(); +#else + return 1; +#endif + GGML_UNUSED(model); +} + +static size_t llama_get_device_memory(const llama_model & model, int device) { +#if defined(GGML_USE_RPC) + size_t total; + size_t free; + std::string endpoint = model.rpc_servers[device]; + ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total); + return free; +#elif defined(GGML_USE_CUDA) + size_t total; + size_t free; + ggml_backend_cuda_get_device_memory(device, &free, &total); + return free; +#elif defined(GGML_USE_SYCL) + size_t total; + size_t free; + ggml_backend_sycl_get_device_memory(device, &free, &total); + return free; +#elif defined(GGML_USE_VULKAN) + size_t total; + size_t free; + ggml_backend_vk_get_device_memory(device, &free, &total); + return free; +#else + return 1; +#endif + GGML_UNUSED(model); + GGML_UNUSED(device); +} + // // kv cache helpers // @@ -4791,13 +4810,13 @@ static bool llm_load_tensors( if (split_mode == LLAMA_SPLIT_MODE_LAYER) { // calculate the split points - int device_count = llama_get_device_count(); + int device_count = llama_get_device_count(model); bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; }); std::vector<float> splits(device_count); if (all_zero) { // default split, by free memory for (int i = 0; i < device_count; ++i) { - splits[i] = llama_get_device_memory(i); + splits[i] = llama_get_device_memory(model, i); } } else { std::copy(tensor_split, tensor_split + device_count, splits.begin()); @@ -4817,35 +4836,35 @@ static bool llm_load_tensors( int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1); for (int64_t i = i_gpu_start; i < n_layer; ++i) { int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin(); - model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu); + model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu); } // assign the output layer if (n_gpu_layers > n_layer) { int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin(); - model.buft_output = llama_default_buffer_type_offload(layer_gpu); + model.buft_output = llama_default_buffer_type_offload(model, layer_gpu); } else { model.buft_output = llama_default_buffer_type_cpu(true); } } else { ggml_backend_buffer_type_t split_buft; if (split_mode == LLAMA_SPLIT_MODE_ROW) { - split_buft = llama_default_buffer_type_split(main_gpu, tensor_split); + split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split); } else { // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported - split_buft = llama_default_buffer_type_offload(main_gpu); + split_buft = llama_default_buffer_type_offload(model, main_gpu); } // assign the repeating layers for (int64_t i = i_gpu_start; i < n_layer; ++i) { model.buft_layer[i] = { split_buft, - llama_default_buffer_type_offload(main_gpu) + llama_default_buffer_type_offload(model, main_gpu) }; } // assign the output layer if (n_gpu_layers > n_layer) { model.buft_output = { split_buft, - llama_default_buffer_type_offload(main_gpu) + llama_default_buffer_type_offload(model, main_gpu) }; } else { model.buft_output = llama_default_buffer_type_cpu(true); @@ -15390,6 +15409,7 @@ struct llama_model_params llama_model_default_params() { /*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER, /*.main_gpu =*/ 0, /*.tensor_split =*/ nullptr, + /*.rpc_servers =*/ nullptr, /*.progress_callback =*/ nullptr, /*.progress_callback_user_data =*/ nullptr, /*.kv_overrides =*/ nullptr, @@ -15460,7 +15480,9 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { } size_t llama_max_devices(void) { -#if defined(GGML_USE_METAL) +#if defined(GGML_USE_RPC) + return GGML_RPC_MAX_SERVERS; +#elif defined(GGML_USE_METAL) return 1; #elif defined(GGML_USE_CUDA) return GGML_CUDA_MAX_DEVICES; @@ -15483,7 +15505,7 @@ bool llama_supports_mlock(void) { bool llama_supports_gpu_offload(void) { #if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ - defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) + defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. return true; #else @@ -15546,7 +15568,17 @@ struct llama_model * llama_load_model_from_file( return true; }; } - + if (params.rpc_servers != nullptr) { + // split the servers set them into model->rpc_servers + std::string servers(params.rpc_servers); + size_t pos = 0; + while ((pos = servers.find(",")) != std::string::npos) { + std::string server = servers.substr(0, pos); + model->rpc_servers.push_back(server); + servers.erase(0, pos + 1); + } + model->rpc_servers.push_back(servers); + } int status = llama_model_load(path_model, *model, params); GGML_ASSERT(status <= 0); if (status < 0) { @@ -15693,7 +15725,17 @@ struct llama_context * llama_new_context_with_model( if (!hparams.vocab_only) { // initialize backends -#ifdef GGML_USE_METAL +#if defined(GGML_USE_RPC) + for (auto & server : model->rpc_servers) { + ggml_backend_t backend = ggml_backend_rpc_init(server.c_str()); + if (backend == nullptr) { + LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str()); + llama_free(ctx); + return nullptr; + } + ctx->backends.push_back(backend); + } +#elif defined(GGML_USE_METAL) if (model->n_gpu_layers > 0) { ctx->backend_metal = ggml_backend_metal_init(); if (ctx->backend_metal == nullptr) { @@ -15850,7 +15892,7 @@ struct llama_context * llama_new_context_with_model( // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = - llama_get_device_count() > 1 && + llama_get_device_count(*model) > 1 && model->n_gpu_layers > (int)model->hparams.n_layer && model->split_mode == LLAMA_SPLIT_MODE_LAYER && params.offload_kqv; |