summaryrefslogtreecommitdiff
path: root/llama.h
diff options
context:
space:
mode:
authorslaren <slarengh@gmail.com>2024-01-12 20:07:38 +0100
committerGitHub <noreply@github.com>2024-01-12 20:07:38 +0100
commite7e4df031b9e29d4b55a4e0b0295187f6b213db1 (patch)
tree93211b7800be3c2c5f9eb1d55f3b7b3acdc56c9b /llama.h
parent584d674be622fbf1578694ada6e62eebedbfd377 (diff)
llama : ggml-backend integration (#4766)
* llama : ggml-backend integration * ggml-backend : add names to buffers * fix unmap after loading * batched-bench : add tensor_split param * llama : check for null tensor_split * ggml-backend : increase GGML_MAX_BACKENDS * improve graph splitting, partial fix for --no-kv-offload * cuda : add ggml-backend split buffer support * cuda : do not create buffer types for devices that don't exist (fixes usage without CUDA devices available) * ggml : fix null backend dereference (#4807) * ggml : fix null backend dereference * ggml : also check ggml_backend_is_cpu * test-backend-ops : check buffer allocation failures * llama : add cparam (split_mode) and command line argument (--split-mode, -sm) to configure the split mode (none, layer or row) * ggml : fix mul_mat_id work size * llama : rewrite session kv load/set without graphs * minor * llama : only initialize used backends, free backends on context free * llama : abort ctx if cuda backend init fails * llama : rewrite lora with ggml-backend and compute on CPU ggml-ci * llama : only map to a backend buffer the region of the file mapping containing the tensors used in the buffer * opencl : add ggml-backend buffer type * cuda : only use batched_cublas with batched mat muls (fixes fp16 tg perf) * llama : on Metal, by default offload the full model ggml-ci * metal : page align the data ptr (#4854) * Apply suggestions from code review Co-authored-by: Johannes Gäßler <johannesg@5d6.de> * cuda : fix split buffer free * address review comments * llama-bench : add split-mode parameter * fix whitespace * opencl : fix double initialization * server : add --split-mode parameter * use async copy and compute to improve multi-gpu performance ggml-ci * use async memcpys to copy the graph outputs to the CPU * fix opencl * use a host buffer for the cpu compute buffer for faster copies to the gpu --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h18
1 files changed, 16 insertions, 2 deletions
diff --git a/llama.h b/llama.h
index 43d41b8f..689e12d7 100644
--- a/llama.h
+++ b/llama.h
@@ -118,6 +118,12 @@ extern "C" {
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
};
+ enum llama_split_mode {
+ LLAMA_SPLIT_NONE = 0, // single GPU
+ LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
+ LLAMA_SPLIT_ROW = 2, // split rows across GPUs
+ };
+
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@@ -180,8 +186,16 @@ extern "C" {
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
- int32_t main_gpu; // the GPU that is used for scratch and small tensors
- const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
+ enum llama_split_mode split_mode; // how to split the model across multiple GPUs
+
+ // main_gpu interpretation depends on split_mode:
+ // LLAMA_SPLIT_NONE: the GPU that is used for the entire model
+ // LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
+ // LLAMA_SPLIT_LAYER: ignored
+ int32_t main_gpu;
+
+ // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
+ const float * tensor_split;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.