summaryrefslogtreecommitdiff
path: root/llama.cpp
diff options
context:
space:
mode:
authorcompilade <113953597+compilade@users.noreply.github.com>2024-03-26 10:46:41 -0400
committerGitHub <noreply@github.com>2024-03-26 16:46:41 +0200
commit557410b8f06380560155ac7fcb8316d71ddc9837 (patch)
treecccf337eae34b1e1e46aef188678132b777b2ff7 /llama.cpp
parent55c1b2a3bbd470e9e2a3a0618b92cf64a885f806 (diff)
llama : greatly reduce output buffer memory usage (#6122)
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.cpp')
-rw-r--r--llama.cpp694
1 files changed, 560 insertions, 134 deletions
diff --git a/llama.cpp b/llama.cpp
index 68c360c7..22db79d6 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1777,6 +1777,7 @@ struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
+ uint32_t n_seq_max;
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -2139,20 +2140,20 @@ struct llama_context {
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
- // decode output (2-dimensional array: [n_tokens][n_vocab])
- size_t logits_size = 0;
- float * logits = nullptr;
+ // decode output (2-dimensional array: [n_outputs][n_vocab])
+ size_t logits_size = 0; // capacity (of floats) for logits
+ float * logits = nullptr;
+
+ std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
+ size_t output_size = 0; // capacity (of tokens positions) for the output buffers
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
-#ifndef NDEBUG
- // guard against access to unset logits
- std::vector<bool> logits_valid;
-#endif
bool logits_all = false;
- // embeddings output (2-dimensional array: [n_tokens][n_embd])
+ // embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
- size_t embd_size = 0;
- float * embd = nullptr;
+ size_t embd_size = 0; // capacity (of floats) for embeddings
+ float * embd = nullptr;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -2169,14 +2170,15 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
+ struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_KQ_pos; // F32 [n_kv]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
- struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
- struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
+ struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+ struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
// control vectors
struct llama_control_vector cvec;
@@ -5846,7 +5848,8 @@ struct llm_build_context {
const float norm_rms_eps;
const int32_t n_tokens;
- const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
+ const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
+ const int32_t n_outputs;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
@@ -5893,6 +5896,7 @@ struct llm_build_context {
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
+ n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
@@ -5914,6 +5918,7 @@ struct llm_build_context {
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
+ lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_pos = nullptr;
lctx.inp_K_shift = nullptr;
@@ -6037,6 +6042,13 @@ struct llm_build_context {
return lctx.inp_pos;
}
+ struct ggml_tensor * build_inp_out_ids() {
+ lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
+ cb(lctx.inp_out_ids, "inp_out_ids", -1);
+ ggml_set_input(lctx.inp_out_ids);
+ return lctx.inp_out_ids;
+ }
+
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
@@ -6093,6 +6105,9 @@ struct llm_build_context {
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6160,6 +6175,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -6339,6 +6362,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -6454,6 +6484,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = cur;
// feed forward
@@ -6497,6 +6535,9 @@ struct llm_build_context {
struct ggml_cgraph * build_grok() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6568,6 +6609,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// Grok
// if attn_out_norm is present then apply it before adding the input
if (model.layers[il].attn_out_norm) {
@@ -6745,6 +6794,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -6942,6 +6998,13 @@ struct llm_build_context {
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
cb(ffn_inp, "ffn_inp", il);
@@ -7031,6 +7094,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7188,6 +7258,13 @@ struct llm_build_context {
}
cb(cur, "kqv_out", il);
+ if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
@@ -7310,6 +7387,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -7408,6 +7492,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -7521,6 +7612,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7627,6 +7725,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7739,6 +7844,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7857,6 +7969,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
+ }
+
// FF
{
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
@@ -7954,6 +8074,14 @@ struct llm_build_context {
cur = attention_norm;
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
@@ -8046,6 +8174,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -8146,6 +8281,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -8255,6 +8397,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8365,6 +8514,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8488,6 +8644,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
@@ -8602,6 +8765,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
@@ -8714,6 +8884,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8861,6 +9038,15 @@ struct llm_build_context {
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ x = ggml_get_rows(ctx0, x, inp_out_ids);
+ y = ggml_get_rows(ctx0, y, inp_out_ids);
+ z = ggml_get_rows(ctx0, z, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
@@ -8963,6 +9149,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * attn_out = cur;
// feed-forward network
@@ -9260,9 +9453,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
+ if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
+ int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+
+ if (lctx.n_outputs == n_tokens) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = i;
+ }
+ } else if (batch.logits) {
+ int32_t n_outputs = 0;
+ for (int i = 0; i < n_tokens; ++i) {
+ if (batch.logits[i]) {
+ data[n_outputs++] = i;
+ }
+ }
+ // the graph needs to have been passed the correct number of outputs
+ GGML_ASSERT(lctx.n_outputs == n_outputs);
+ } else if (lctx.n_outputs == 1) {
+ // only keep last output
+ data[0] = n_tokens - 1;
+ } else {
+ GGML_ASSERT(lctx.n_outputs == 0);
+ }
+ }
+
GGML_ASSERT(
+ // (!a || b) is a logical implication (a -> b)
+ // !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
- "non-causal attention with generative models is not supported"
+ "causal attention with embedding models is not supported"
);
if (lctx.inp_KQ_mask) {
@@ -9441,6 +9664,74 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
+// Make sure enough space is available for outputs.
+// Returns max number of outputs for which space was reserved.
+static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
+ const auto & cparams = lctx.cparams;
+ const auto & hparams = lctx.model.hparams;
+
+ const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+
+ const auto n_batch = cparams.n_batch;
+ const auto n_vocab = hparams.n_vocab;
+ const auto n_embd = hparams.n_embd;
+
+ // TODO: use a per-batch flag for logits presence instead
+ const bool has_logits = cparams.causal_attn;
+ const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+ const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+ const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
+
+ if (lctx.output_ids.empty()) {
+ // init, never resized afterwards
+ lctx.output_ids.resize(n_batch);
+ }
+
+ const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
+
+ // alloc only when more than the current capacity is required
+ // TODO: also consider shrinking the buffer
+ if (!lctx.buf_output || prev_size < new_size) {
+ if (lctx.buf_output) {
+#ifndef NDEBUG
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+ ggml_backend_buffer_free(lctx.buf_output);
+ lctx.buf_output = nullptr;
+ lctx.logits = nullptr;
+ lctx.embd = nullptr;
+ }
+
+ lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+ if (lctx.buf_output == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+ return 0;
+ }
+ }
+
+ float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+
+ lctx.logits = has_logits ? output_base : nullptr;
+ lctx.embd = has_embd ? output_base + logits_size : nullptr;
+
+ lctx.output_size = n_outputs_max;
+ lctx.logits_size = logits_size;
+ lctx.embd_size = embd_size;
+
+ // set all ids as invalid (negative)
+ std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+
+ ggml_backend_buffer_clear(lctx.buf_output, 0);
+
+ lctx.n_outputs = 0;
+
+ return n_outputs_max;
+}
+
+
static void llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
@@ -9516,16 +9807,8 @@ static int llama_decode_internal(
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
-
- auto * logits_out = lctx.logits;
-
-#ifndef NDEBUG
- auto & logits_valid = lctx.logits_valid;
- logits_valid.clear();
- logits_valid.resize(n_tokens_all);
-
- memset(logits_out, 0, lctx.logits_size*sizeof(float));
-#endif
+ uint32_t n_outputs = 0;
+ uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
@@ -9534,6 +9817,38 @@ static int llama_decode_internal(
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
+ // count outputs
+ if (batch_all.logits) {
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ n_outputs += batch_all.logits[i] != 0;
+ }
+ } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+ n_outputs = n_tokens_all;
+ } else {
+ // keep last output only
+ n_outputs = 1;
+ }
+
+ // reserve output buffer
+ if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+ return -2;
+ };
+
+ // set output mappings
+ if (batch_all.logits) {
+ int32_t i_logits = 0;
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ if (batch_all.logits[i]) {
+ lctx.output_ids[i] = i_logits++;
+ }
+ }
+ } else {
+ for (uint32_t i = 0; i < n_outputs; ++i) {
+ lctx.output_ids[i] = i;
+ }
+ }
+
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
@@ -9549,6 +9864,27 @@ static int llama_decode_internal(
/* .all_seq_id = */ batch_all.all_seq_id,
};
+ // count the outputs in this u_batch
+ {
+ int32_t n_outputs_new = 0;
+
+ if (u_batch.logits) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ n_outputs_new += u_batch.logits[i] != 0;
+ }
+ } else if (n_outputs == n_tokens_all) {
+ n_outputs_new = n_tokens;
+ } else {
+ // keep last output only
+ if (cur_token + n_tokens >= n_tokens_all) {
+ n_outputs_new = 1;
+ }
+ }
+
+ // needs to happen before the graph is built
+ lctx.n_outputs = n_outputs_new;
+ }
+
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
@@ -9612,23 +9948,37 @@ static int llama_decode_internal(
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
- if (!hparams.causal_attn) {
+ if (lctx.n_outputs == 0) {
+ // no output
+ res = nullptr;
+ embd = nullptr;
+ } else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
- } else {
- if (strcmp(res->name, "result_output") == 0) {
- // the token embeddings could be the second to last tensor, or the third to last tensor
- if (strcmp(embd->name, "result_norm") != 0) {
- embd = gf->nodes[gf->n_nodes - 3];
- GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
- }
- } else {
- GGML_ASSERT(false && "missing result_output tensor");
+ } else if (cparams.embeddings) {
+ // the embeddings could be in the second to last tensor, or any of the previous tensors
+ int i_embd = gf->n_nodes - 2;
+ for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+ i_embd = gf->n_nodes - i;
+ if (i_embd < 0) { break; }
+ embd = gf->nodes[i_embd];
+ }
+ GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+ // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+ if (!cparams.causal_attn) {
+ res = nullptr; // do not extract logits when not needed
+ // skip computing logits
+ // TODO: is this safe?
+ gf->n_nodes = i_embd + 1;
}
+ } else {
+ embd = nullptr; // do not extract embeddings when not needed
+ GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -9671,50 +10021,23 @@ static int llama_decode_internal(
//}
// extract logits
- // TODO: do not compute and extract logits if only embeddings are needed
- // update the graphs to skip "result_output" if logits are not needed
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
- if (u_batch.logits) {
- int32_t i_first = -1;
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] && i_first == -1) {
- i_first = (int32_t) i;
- }
- if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
- if (i_first != -1) {
- int i_last = u_batch.logits[i] == 0 ? i : i + 1;
- // extract logits for the range [i_first, i_last)
- // group the requests to minimize the number of calls to the backend
- ggml_backend_tensor_get_async(backend_res, res,
- logits_out + n_vocab*(cur_token + i_first),
- i_first*n_vocab*sizeof(float),
- (i_last - i_first)*n_vocab*sizeof(float));
- i_first = -1;
- }
- }
-#ifndef NDEBUG
- logits_valid[cur_token + i] = u_batch.logits[i] != 0;;
-#endif
- }
- } else if (lctx.logits_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
-#ifndef NDEBUG
- std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
-#endif
- } else {
- if (cur_token + n_tokens >= n_tokens_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
-#ifndef NDEBUG
- logits_valid[0] = true;
-#endif
- }
+ GGML_ASSERT(lctx.logits != nullptr);
+
+ float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+ ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}
// extract embeddings
- if (cparams.embeddings && embd) {
+ if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
@@ -9722,16 +10045,14 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
- auto & embd_out = lctx.embd;
-
- if (u_batch.logits) {
- //embd_out.resize(n_embd * n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] == 0) {
- continue;
- }
- ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
- }
+ GGML_ASSERT(lctx.embd != nullptr);
+ float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
@@ -9758,6 +10079,7 @@ static int llama_decode_internal(
} break;
}
}
+ n_outputs_prev += lctx.n_outputs;
}
// wait for the computation to finish (automatically done when obtaining the model output)
@@ -13531,7 +13853,7 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
- // TODO: maybe add n_seq_max here too
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -13733,25 +14055,12 @@ struct llama_context * llama_new_context_with_model(
// graph outputs buffer
{
- // resized during inference, reserve maximum
- ctx->logits_size = hparams.n_vocab*cparams.n_batch;
- ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
-
- const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
-
- ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
- if (ctx->buf_output == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
+ // resized during inference when a batch uses more outputs
+ if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
+ LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
llama_free(ctx);
return nullptr;
}
- ggml_backend_buffer_clear(ctx->buf_output, 0);
-
-
- ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
- if (params.embeddings) {
- ctx->embd = ctx->logits + ctx->logits_size;
- }
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(ctx->buf_output),
@@ -14268,27 +14577,33 @@ void llama_kv_cache_update(struct llama_context * ctx) {
// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {
+ const auto & cparams = ctx->cparams;
+ const auto & hparams = ctx->model.hparams;
+
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE;
+ const size_t s_n_outputs = sizeof(size_t);
+ // assume worst case for outputs although only currently set ones are serialized
+ const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t);
const size_t s_logits_size = sizeof(size_t);
- // assume worst case for logits although only currently set ones are serialized
- const size_t s_logits = ctx->logits_size * sizeof(float);
+ const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
const size_t s_embedding_size = sizeof(size_t);
- const size_t s_embedding = ctx->embd_size * sizeof(float);
+ const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
- // TODO: assume the max is more than 1 seq_id per KV cell
- const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+ const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
const size_t s_total = (
+ s_rng_size
+ s_rng
+ + s_n_outputs
+ + s_output_pos
+ s_logits_size
+ s_logits
+ s_embedding_size
@@ -14363,7 +14678,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
std::ostringstream rng_ss;
rng_ss << ctx->rng;
- const std::string & rng_str = rng_ss.str();
+ const std::string & rng_str = rng_ss.str();
const size_t rng_size = rng_str.size();
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
@@ -14372,25 +14687,61 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(rng_str.data(), rng_size);
}
- // copy logits
+ // copy outputs
{
- const size_t logits_size = ctx->logits_size;
+ // Can't use ctx->n_outputs because it's not for the
+ // entire last batch when n_ubatch is smaller than n_batch
+ size_t n_outputs = 0;
- data_ctx->write(&logits_size, sizeof(logits_size));
+ // copy output ids
+ {
+ std::vector<int32_t> output_pos;
- if (logits_size) {
- data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ const size_t n_batch = ctx->cparams.n_batch;
+ const auto & output_ids = ctx->output_ids;
+
+ output_pos.resize(ctx->output_size);
+
+ // build a more compact representation of the output ids
+ for (size_t i = 0; i < n_batch; ++i) {
+ // map an output id to a position in the batch
+ int32_t pos = output_ids[i];
+ if (pos >= 0) {
+ if ((size_t) pos >= n_outputs) {
+ n_outputs = pos + 1;
+ }
+ GGML_ASSERT((size_t) pos < ctx->output_size);
+ output_pos[pos] = i;
+ }
+ }
+
+ data_ctx->write(&n_outputs, sizeof(n_outputs));
+
+ if (n_outputs) {
+ data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
+ }
}
- }
- // copy embeddings
- {
- const size_t embeddings_size = ctx->embd_size;
+ // copy logits
+ {
+ const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
- data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+ data_ctx->write(&logits_size, sizeof(logits_size));
- if (embeddings_size) {
- data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ if (logits_size) {
+ data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ }
+ }
+
+ // copy embeddings
+ {
+ const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+
+ data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+
+ if (embeddings_size) {
+ data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ }
}
}
@@ -14403,9 +14754,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
- const size_t kv_buf_size = kv_self.total_size();
+ // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
const uint32_t kv_size = kv_self.size;
+ const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
@@ -14414,6 +14766,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
+ const size_t pre_kv_buf_size = data_ctx->get_size_written();
+
std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
@@ -14443,6 +14797,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(tmp_buf.data(), tmp_buf.size());
}
}
+ GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
}
for (uint32_t i = 0; i < kv_head; ++i) {
@@ -14487,6 +14842,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(!rng_ss.fail());
}
+ // set output ids
+ {
+ size_t n_outputs;
+ std::vector<int32_t> output_pos;
+
+ memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+
+ GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+
+ if (n_outputs) {
+ output_pos.resize(n_outputs);
+ memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
+ inp += n_outputs * sizeof(int32_t);
+
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+ int32_t id = output_pos[i];
+ GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+ ctx->output_ids[id] = i;
+ }
+ }
+ }
+
// set logits
{
size_t logits_size;
@@ -14507,7 +14884,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
- GGML_ASSERT(ctx->embd_size == embeddings_size);
+ GGML_ASSERT(ctx->embd_size >= embeddings_size);
if (embeddings_size) {
memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
@@ -14534,8 +14911,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
+ if (kv_self.size != kv_size) {
+ // the KV cache needs to be big enough to load all the KV cells from the saved state
+ GGML_ASSERT(kv_self.size >= kv_head);
+
+ LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
+ __func__, kv_head, kv_size, kv_self.size);
+ }
+
if (kv_buf_size) {
- GGML_ASSERT(kv_self.total_size() == kv_buf_size);
+ const size_t pre_kv_buf_size = inp - src;
+
+ GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
@@ -14555,23 +14942,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+ const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
inp += v_row_size;
}
}
+ GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
- GGML_ASSERT(kv_self.size == kv_size);
+ llama_kv_cache_clear(ctx);
ctx->kv_self.head = kv_head;
- ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;
- ctx->kv_self.cells.resize(kv_size);
-
for (uint32_t i = 0; i < kv_head; ++i) {
llama_pos pos;
size_t seq_id_size;
@@ -14588,11 +14973,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ctx->kv_self.cells[i].seq_id.insert(seq_id);
}
}
-
- for (uint32_t i = kv_head; i < kv_size; ++i) {
- ctx->kv_self.cells[i].pos = -1;
- ctx->kv_self.cells[i].seq_id.clear();
- }
}
const size_t nread = inp - src;
@@ -14798,11 +15178,33 @@ float * llama_get_logits(struct llama_context * ctx) {
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
- assert(ctx->logits_valid.at(i));
-
llama_synchronize(ctx);
- return ctx->logits + i*ctx->model.hparams.n_vocab;
+ try {
+ if (ctx->logits == nullptr) {
+ throw std::runtime_error("no logits");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->logits + j*ctx->model.hparams.n_vocab;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings(struct llama_context * ctx) {
@@ -14814,7 +15216,31 @@ float * llama_get_embeddings(struct llama_context * ctx) {
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
llama_synchronize(ctx);
- return ctx->embd + i*ctx->model.hparams.n_embd;
+ try {
+ if (ctx->embd == nullptr) {
+ throw std::runtime_error("no embeddings");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->embd + j*ctx->model.hparams.n_embd;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {