summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
author0cc4m <picard12@live.de>2024-02-07 07:54:50 +0100
committerGitHub <noreply@github.com>2024-02-07 07:54:50 +0100
commitee1628bdfea8b0079fed0140ac2f00ef1b465b57 (patch)
tree42ee597afa79a6c4e0bb772d78a7cfcd54777696 /llama.cpp
parented0bf32290ee5b30ffad5becd99cbecef74aedd7 (diff)
Basic Vulkan Multi-GPU implementation (#5321)
* Initial Vulkan multi-gpu implementation Move most global variables into backend context * Add names to backend device functions * Add further missing cleanup code * Reduce code duplication in tensor split layer assignment * generalize LLAMA_SPLIT_LAYER for all backends, do not expose device count and memory in llama.h * Only do device info print in the beginning and initialize one backend for cpu assist Add missing cleanup code * Rework backend memory management to make sure devices and buffers get properly allocated and freed * Rename cpu assist free function --------- Co-authored-by: slaren <slarengh@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp69
1 files changed, 49 insertions, 20 deletions
diff --git a/llama.cpp b/llama.cpp
index f3c5146d..c45ae1d5 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1355,7 +1355,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
#elif defined(GGML_USE_CUBLAS)
buft = ggml_backend_cuda_buffer_type(gpu);
#elif defined(GGML_USE_VULKAN)
- buft = ggml_backend_vk_buffer_type();
+ buft = ggml_backend_vk_buffer_type(gpu);
#elif defined(GGML_USE_SYCL)
buft = ggml_backend_sycl_buffer_type(gpu);
#elif defined(GGML_USE_CLBLAST)
@@ -1392,6 +1392,33 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
GGML_UNUSED(tensor_split);
}
+static size_t llama_get_device_count() {
+#if defined(GGML_USE_CUBLAS)
+ return ggml_backend_cuda_get_device_count();
+#elif defined(GGML_USE_VULKAN)
+ return ggml_backend_vk_get_device_count();
+#else
+ return 1;
+#endif
+}
+
+static size_t llama_get_device_memory(int device) {
+#if defined(GGML_USE_CUBLAS)
+ size_t total;
+ size_t free;
+ ggml_backend_cuda_get_device_memory(device, &total, &free);
+ return free;
+#elif defined(GGML_USE_VULKAN)
+ size_t total;
+ size_t free;
+ ggml_backend_vk_get_device_memory(device, &total, &free);
+ return free;
+#else
+ return 1;
+ GGML_UNUSED(device);
+#endif
+}
+
//
// globals
//
@@ -1763,6 +1790,10 @@ struct llama_context {
ggml_backend_free(backend);
}
+#ifdef GGML_USE_VULKAN
+ ggml_vk_free_cpu_assist();
+#endif
+
ggml_backend_buffer_free(buf_input);
ggml_free(ctx_input);
}
@@ -3436,22 +3467,18 @@ static bool llm_load_tensors(
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
}
-#ifdef GGML_USE_CUBLAS
if (split_mode == LLAMA_SPLIT_LAYER) {
// calculate the split points
- int device_count = ggml_backend_cuda_get_device_count();
+ int device_count = llama_get_device_count();
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
- float splits[GGML_CUDA_MAX_DEVICES];
+ std::vector<float> splits(device_count);
if (all_zero) {
// default split, by free memory
for (int i = 0; i < device_count; ++i) {
- size_t total;
- size_t free;
- ggml_backend_cuda_get_device_memory(i, &total, &free);
- splits[i] = free;
+ splits[i] = llama_get_device_memory(i);
}
} else {
- std::copy(tensor_split, tensor_split + device_count, splits);
+ std::copy(tensor_split, tensor_split + device_count, splits.begin());
}
// sum and normalize the splits to get the split points
@@ -3467,19 +3494,17 @@ static bool llm_load_tensors(
// assign the repeating layers to the devices according to the splits
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
for (int64_t i = i_gpu_start; i < n_layer; ++i) {
- int layer_gpu = std::upper_bound(splits, splits + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits;
+ int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu);
}
// assign the output layer
if (n_gpu_layers > n_layer) {
- int layer_gpu = std::upper_bound(splits, splits + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits;
+ int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
model.buft_output = llama_default_buffer_type_offload(layer_gpu);
} else {
model.buft_output = llama_default_buffer_type_cpu(true);
}
- } else
-#endif
- {
+ } else {
ggml_backend_buffer_type_t split_buft;
if (split_mode == LLAMA_SPLIT_ROW) {
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
@@ -10483,6 +10508,8 @@ size_t llama_max_devices(void) {
return GGML_CUDA_MAX_DEVICES;
#elif defined(GGML_USE_SYCL)
return GGML_SYCL_MAX_DEVICES;
+#elif defined(GGML_USE_VULKAN)
+ return GGML_VK_MAX_DEVICES;
#else
return 1;
#endif
@@ -10690,13 +10717,15 @@ struct llama_context * llama_new_context_with_model(
}
#elif defined(GGML_USE_VULKAN)
if (model->n_gpu_layers > 0) {
- ggml_backend_t backend = ggml_backend_vk_init();
- if (backend == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
- llama_free(ctx);
- return nullptr;
+ for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
+ ggml_backend_t backend = ggml_backend_vk_init(device);
+ if (backend == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
+ llama_free(ctx);
+ return nullptr;
+ }
+ ctx->backends.push_back(backend);
}
- ctx->backends.push_back(backend);
}
#elif defined(GGML_USE_SYCL)
if (model->n_gpu_layers > 0) {