summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp159
1 files changed, 126 insertions, 33 deletions
diff --git a/llama.cpp b/llama.cpp
index d2a52bb0..b8bc0d82 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -165,6 +165,11 @@ struct llama_kv_cache {
if (ctx) {
ggml_free(ctx);
}
+
+#ifdef GGML_USE_CUBLAS
+ ggml_cuda_free_data(k);
+ ggml_cuda_free_data(v);
+#endif // GGML_USE_CUBLAS
}
};
@@ -210,6 +215,7 @@ struct llama_model {
for (size_t i = 0; i < tensors_by_name.size(); ++i) {
ggml_cuda_free_data(tensors_by_name[i].second);
}
+ ggml_cuda_free_scratch();
#elif defined(GGML_USE_CLBLAST)
for (size_t i = 0; i < tensors_by_name.size(); ++i) {
ggml_cl_free_data(tensors_by_name[i].second);
@@ -867,7 +873,8 @@ static bool kv_cache_init(
const struct llama_hparams & hparams,
struct llama_kv_cache & cache,
ggml_type wtype,
- int n_ctx) {
+ int n_ctx,
+ int n_gpu_layers) {
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
@@ -893,6 +900,15 @@ static bool kv_cache_init(
ggml_set_name(cache.k, "cache_k");
ggml_set_name(cache.v, "cache_v");
+#ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer + 1) {
+ ggml_cuda_assign_buffers_no_scratch(cache.v);
+ }
+ if (n_gpu_layers > n_layer + 2) {
+ ggml_cuda_assign_buffers_no_scratch(cache.k);
+ }
+#endif // GGML_USE_CUBLAS
+
return true;
}
@@ -903,6 +919,7 @@ struct llama_context_params llama_context_default_params() {
/*.gpu_layers =*/ 0,
/*.main_gpu =*/ 0,
/*.tensor_split =*/ {0},
+ /*.low_vram =*/ false,
/*.seed =*/ -1,
/*.f16_kv =*/ true,
/*.logits_all =*/ false,
@@ -1011,6 +1028,7 @@ static void llama_model_load_internal(
int n_gpu_layers,
int main_gpu,
const float * tensor_split,
+ bool low_vram,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -1137,18 +1155,34 @@ static void llama_model_load_internal(
ml->ggml_ctx = ctx;
model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
- model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU);
// "output" tensor
{
+ ggml_backend backend_norm;
ggml_backend backend_output;
if (n_gpu_layers > int(n_layer)) { // NOLINT
+ // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+ // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+ backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#else
+ backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
} else {
+ backend_norm = GGML_BACKEND_CPU;
backend_output = GGML_BACKEND_CPU;
}
+ model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm);
model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
+ if (backend_norm == GGML_BACKEND_GPU) {
+ vram_weights += ggml_nbytes(model.norm);
+ }
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+ vram_weights += ggml_nbytes(model.output);
+ }
}
const int i_gpu_start = n_layer - n_gpu_layers;
@@ -1208,22 +1242,47 @@ static void llama_model_load_internal(
(void) vram_scratch;
(void) n_batch;
#ifdef GGML_USE_CUBLAS
- vram_scratch = n_batch * MB;
- ggml_cuda_set_scratch_size(vram_scratch);
- if (n_gpu_layers > 0) {
- fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
- __func__, vram_scratch / MB);
+ if (low_vram) {
+ fprintf(stderr, "%s: not allocating a VRAM scratch buffer due to low VRAM option\n", __func__);
+ ggml_cuda_set_scratch_size(0); // disable scratch
+ } else {
+ vram_scratch = n_batch * MB;
+ ggml_cuda_set_scratch_size(vram_scratch);
+ if (n_gpu_layers > 0) {
+ fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
+ __func__, vram_scratch / MB);
+ }
}
#endif // GGML_USE_CUBLAS
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
- fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu);
+ fprintf(stderr, "%s: offloading %d repeating layers to GPU\n", __func__, n_gpu);
if (n_gpu_layers > (int) hparams.n_layer) {
- fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
+ fprintf(stderr, "%s: offloading non-repeating layers to GPU\n", __func__);
+ }
+ size_t vram_kv_cache = 0;
+ if (n_gpu_layers > (int) hparams.n_layer + 1) {
+ if (low_vram) {
+ fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
+ } else {
+ fprintf(stderr, "%s: offloading v cache to GPU\n", __func__);
+ vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+ }
}
+ if (n_gpu_layers > (int) hparams.n_layer + 2) {
+ if (low_vram) {
+ fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
+ } else {
+ fprintf(stderr, "%s: offloading k cache to GPU\n", __func__);
+ vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+ }
+ }
+ const int max_offloadable_layers = low_vram ? hparams.n_layer + 1 : hparams.n_layer + 3;
+ fprintf(stderr, "%s: offloaded %d/%d layers to GPU\n",
+ __func__, std::min(n_gpu_layers, max_offloadable_layers), hparams.n_layer + 3);
fprintf(stderr, "%s: total VRAM used: %zu MB\n",
- __func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up
+ __func__, (vram_weights + vram_scratch + vram_kv_cache + MB - 1) / MB); // round up
#else
(void) n_gpu_layers;
#endif
@@ -1262,6 +1321,7 @@ static bool llama_model_load(
int n_gpu_layers,
int main_gpu,
float * tensor_split,
+ bool low_vram,
ggml_type memory_type,
bool use_mmap,
bool use_mlock,
@@ -1269,7 +1329,7 @@ static bool llama_model_load(
llama_progress_callback progress_callback,
void *progress_callback_user_data) {
try {
- llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type,
+ llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, low_vram, memory_type,
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
return true;
} catch (const std::exception & err) {
@@ -1345,12 +1405,33 @@ static bool llama_eval_internal(
const int i_gpu_start = n_layer - n_gpu_layers;
(void) i_gpu_start;
+ // offload functions set the tensor output backend to GPU
+ // tensors are GPU-accelerated if any input or the output has been offloaded
+ //
+ // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
+ // in that case ggml_cuda_assign_buffers has no effect
+ offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
+ offload_func_t offload_func_kq = llama_nop;
+ offload_func_t offload_func_v = llama_nop;
+
+#ifdef GGML_USE_CUBLAS
+ if (n_gpu_layers > n_layer) {
+ offload_func_nr = ggml_cuda_assign_buffers;
+ }
+ if (n_gpu_layers > n_layer + 1) {
+ offload_func_v = ggml_cuda_assign_buffers;
+ }
+ if (n_gpu_layers > n_layer + 2) {
+ offload_func_kq = ggml_cuda_assign_buffers;
+ }
+#endif // GGML_USE_CUBLAS
+
for (int il = 0; il < n_layer; ++il) {
offload_func_t offload_func = llama_nop;
#ifdef GGML_USE_CUBLAS
if (il >= i_gpu_start) {
- offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
+ offload_func = ggml_cuda_assign_buffers;
}
#endif // GGML_USE_CUBLAS
@@ -1373,31 +1454,42 @@ static bool llama_eval_internal(
// self-attention
{
// compute Q and K and RoPE them
- struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
- // offload_func(tmpq);
- ggml_set_name(tmpq, "tmpq");
-
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
- // offload_func(tmpk);
+ offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk");
+ struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
+ offload_func_kq(tmpq);
+ ggml_set_name(tmpq, "tmpq");
+
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+ offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
+ offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");
// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
- struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+
+ struct ggml_tensor * tmpv = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
+ offload_func_v(tmpv);
+ ggml_set_name(tmpv, "tmpv");
+
+ struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, tmpv, n_embd, N));
+ offload_func_v(Vcur);
ggml_set_name(Vcur, "Vcur");
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
+ offload_func_kq(k);
ggml_set_name(k, "k");
+
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
( n_ctx)*ggml_element_size(kv_self.v),
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
+ offload_func_v(v);
ggml_set_name(v, "v");
// important: storing RoPE-ed version of K in the KV cache!
@@ -1409,6 +1501,7 @@ static bool llama_eval_internal(
ggml_permute(ctx0,
Qcur,
0, 2, 1, 3);
+ offload_func_kq(Q);
ggml_set_name(Q, "Q");
struct ggml_tensor * K =
@@ -1417,10 +1510,12 @@ static bool llama_eval_internal(
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
0, 2, 1, 3);
+ offload_func_kq(K);
ggml_set_name(K, "K");
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ offload_func_kq(KQ);
ggml_set_name(KQ, "KQ");
// KQ_scaled = KQ / sqrt(n_embd/n_head)
@@ -1429,14 +1524,17 @@ static bool llama_eval_internal(
// KQ_scaled shape [n_past + N, N, n_head, 1]
struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
+ offload_func_kq(KQ_scaled);
ggml_set_name(KQ_scaled, "KQ_scaled");
// KQ_masked = mask_past(KQ_scaled)
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
+ offload_func_kq(KQ_masked);
ggml_set_name(KQ_masked, "KQ_masked");
// KQ = soft_max(KQ_masked)
struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
+ offload_func_v(KQ_soft_max);
ggml_set_name(KQ_soft_max, "KQ_soft_max");
// split cached V into n_head heads
@@ -1446,10 +1544,12 @@ static bool llama_eval_internal(
n_ctx*ggml_element_size(kv_self.v),
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
+ offload_func_v(V);
ggml_set_name(V, "V");
#if 1
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+ offload_func_v(KQV);
ggml_set_name(KQV, "KQV");
#else
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
@@ -1461,12 +1561,14 @@ static bool llama_eval_internal(
// KQV_merged = KQV.permute(0, 2, 1, 3)
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ offload_func_v(KQV_merged);
ggml_set_name(KQV_merged, "KQV_merged");
// cur = KQV_merged.contiguous().view(n_embd, N)
cur = ggml_cpy(ctx0,
KQV_merged,
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
+ offload_func_v(cur);
ggml_set_name(cur, "KQV_merged_contiguous");
// projection (no bias)
@@ -1478,7 +1580,6 @@ static bool llama_eval_internal(
}
lctx.use_buf(ctx0, 1);
- //ggml_cuda_set_scratch(1);
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
offload_func(inpFF);
@@ -1536,32 +1637,24 @@ static bool llama_eval_internal(
}
lctx.use_buf(ctx0, 0);
- //ggml_cuda_set_scratch(0);
// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;
- offload_func_t offload_func = llama_nop;
-
-#ifdef GGML_USE_CUBLAS
- if (n_gpu_layers > n_layer) {
- offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
- }
-#endif // GGML_USE_CUBLAS
// norm
{
cur = ggml_rms_norm(ctx0, inpL);
- offload_func(cur);
+ offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_inpL");
cur = ggml_rms_norm(ctx0, cur);
- offload_func(cur);
+ offload_func_nr(cur);
ggml_set_name(cur, "rms_norm_after");
// cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.norm);
- offload_func(cur);
+ offload_func_nr(cur);
ggml_set_name(cur, "result_norm");
embeddings = cur;
@@ -2552,8 +2645,8 @@ struct llama_context * llama_init_from_file(
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
- if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers,
- params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock,
+ if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers, params.main_gpu,
+ params.tensor_split, params.low_vram, memory_type, params.use_mmap, params.use_mlock,
params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
fprintf(stderr, "%s: failed to load model\n", __func__);
llama_free(ctx);
@@ -2562,7 +2655,7 @@ struct llama_context * llama_init_from_file(
// reserve memory for context buffers
if (!params.vocab_only) {
- if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+ if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx);
return nullptr;