summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorcompilade <113953597+compilade@users.noreply.github.com>2024-03-26 10:46:41 -0400
committerGitHub <noreply@github.com>2024-03-26 16:46:41 +0200
commit557410b8f06380560155ac7fcb8316d71ddc9837 (patch)
treecccf337eae34b1e1e46aef188678132b777b2ff7
parent55c1b2a3bbd470e9e2a3a0618b92cf64a885f806 (diff)
llama : greatly reduce output buffer memory usage (#6122)
* llama : greatly reduce logits memory usage * llama : more compact state saving and reloading * llama : fix lctx.n_outputs not being set before building graph * perplexity : adapt to the logits API changes * perplexity : fix Winogrande, use correct logits for second choice start The first logits used to evaluate the second choice were not from the end of the common prefix; instead, they were the logits from the end of the first choice. This has been corrected. The previous implementation sometimes had outliers in the scores of choices for some tasks, and the logic to skip choices words in the log-likelihood evaluation probably was an attempt to reduce those, but it was complex and didn't quite seem to be the right thing. This is simpler now, and the outlier scores aren't there anymore. * perplexity : normalize spaces and punctuation in Winogrande sentences * llama : fix embedding conditions * llama : fix llama_get_embeddings_ith when the resulting id is 0 * llama : fix wrong n_outputs in llama_set_inputs A mismatch happened when using a smaller n_ubatch than n_batch and then using llama_batch_get_one(). The decision of what n_outputs should be now almost fully depends on how lctx.n_outputs is set in llama_decode_internal. The conditions are simpler this way. * llama : when saving the state, recalculate n_outputs This ensures the correct number of outputs for the entire previous batch is stored in the session file, even when n_ubatch is smaller than n_batch. * llama : fix not-skipping outputs of non-causal models * llama : fix running a batch with n_outputs == 0 It previously worked because lctx.inp_out_ids was not initialized, so it pointed to some garbage address which was somehow still valid when I ran my tests. * llama : keep same graph topology even when n_outputs == 0 * ggml : saner ggml_can_repeat with empty tensors * ggml : future-proof ggml_is_empty by using GGML_MAX_DIMS - 1 * ggml : do not multi-thread ops returning empty tensors * ggml : make ggml_is_empty public and work with views * llama : use a vector for ctx->output_ids * llama : rework reallocation logic for llama_output_reserve Now comparing the actual size with the new total size of the output buffer to allow more efficient enabling and disabling of the embeddings and/or logits output in the future. * ggml : skip empty tensors in all backends * llama : fix llama_output_reserve nullptr deref when new_size is 0 * perplexity : make Winogrande work as it does on master The problems with the Winogrande implementation will need to be fixed in a separate PR to ease review. * llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them * llama : handle errors from llama_output_reserve at call sites * perplexity : make hellaswag and multiple-choice outputs identical to master Due to how the KV cache is updated, the logprobs for tokens in a batch are very slightly affected by the other tokens present in the batch, so to make hellaswag and multiple-choice return exactly the same results as on master, the last token of each sequence needs to be evaluated even though its output is not used at all. This will probably be changed back in the future to make these benchmarks a tiny bit faster. * perplexity : fix division by zero when using less than 100 multiple-choice tasks * llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. * llama : minor ggml-ci * readme : update recent API changes, and warn about Vulkan --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r--README.md10
-rw-r--r--examples/imatrix/imatrix.cpp1
-rw-r--r--examples/parallel/parallel.cpp1
-rw-r--r--examples/perplexity/perplexity.cpp129
-rw-r--r--examples/server/server.cpp3
-rw-r--r--examples/speculative/speculative.cpp1
-rw-r--r--ggml-cuda.cu2
-rw-r--r--ggml-kompute.cpp4
-rw-r--r--ggml-metal.m4
-rw-r--r--ggml-opencl.cpp5
-rw-r--r--ggml-sycl.cpp2
-rw-r--r--ggml-vulkan.cpp2
-rw-r--r--ggml.c20
-rw-r--r--ggml.h1
-rw-r--r--llama.cpp694
-rw-r--r--llama.h24
16 files changed, 705 insertions, 198 deletions
diff --git a/README.md b/README.md
index ce678f0c..a56a6004 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
### Recent API changes
+- [2024 Mar 26] Logits and embeddings API updated for compactness https://github.com/ggerganov/llama.cpp/pull/6122
- [2024 Mar 13] Add `llama_synchronize()` + `llama_context_params.n_ubatch` https://github.com/ggerganov/llama.cpp/pull/6017
- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_seq_max()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
@@ -630,6 +631,15 @@ Building the program with BLAS support may lead to some performance improvements
- #### Vulkan
+> [!WARNING]
+>
+> Vulkan support has been broken in https://github.com/ggerganov/llama.cpp/pull/6122
+> due to relying on `GGML_OP_GET_ROWS` which is not yet properly supported by the Vulkan backend,
+> but should be fixed relatively soon (possibly in https://github.com/ggerganov/llama.cpp/pull/6155
+> (ref: https://github.com/ggerganov/llama.cpp/pull/6122#issuecomment-2015327635)).
+>
+> Meanwhile, if you want to use the Vulkan backend, you should use the commit right before the breaking change, https://github.com/ggerganov/llama.cpp/commit/55c1b2a3bbd470e9e2a3a0618b92cf64a885f806
+
**With docker**:
You don't need to install Vulkan SDK. It will be installed inside the container.
diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 264e73f4..12d34462 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -424,6 +424,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params, bool
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
+ // TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index a2ef0fb0..f66c9101 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -132,7 +132,6 @@ int main(int argc, char ** argv) {
llama_context * ctx = NULL;
// load the target model
- params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
// load the prompts from an external file if there are any
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index d766aef6..c70385c6 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -380,6 +380,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
+ // TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
@@ -552,6 +553,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
+ int n_outputs = 0;
+
batch.n_tokens = 0;
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
@@ -566,11 +569,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
- batch.token[idx] = tokens[seq_start + k];
- batch.pos[idx] = j*n_batch + k;
- batch.n_seq_id[idx] = 1;
- batch.seq_id[idx][0] = seq;
- batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+ batch.token [idx] = tokens[seq_start + k];
+ batch.pos [idx] = j*n_batch + k;
+ batch.n_seq_id[idx] = 1;
+ batch.seq_id [idx][0] = seq;
+ batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0;
+
+ n_outputs += batch.logits[idx] != 0;
}
batch.n_tokens += batch_size;
@@ -583,9 +588,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
return {tokens, -1, logit_history, prob_history};
}
- if (num_batches > 1) {
+ if (num_batches > 1 && n_outputs > 0) {
const auto * batch_logits = llama_get_logits(ctx);
- logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+ logits.insert(logits.end(), batch_logits, batch_logits + n_outputs * n_vocab);
}
}
@@ -604,14 +609,15 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
for (int seq = 0; seq < n_seq_batch; seq++) {
- const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+ const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) {
- process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+ process_logits(logits_stream, n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
workers, log_probs, nll, nll2);
} else {
- process_logits(n_vocab, all_logits + first*n_vocab,
+ process_logits(n_vocab, all_logits,
tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + start + seq*n_ctx + first,
@@ -652,6 +658,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
}
static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<float> & batch_logits, int32_t n_batch, int32_t n_vocab) {
+ int prev_outputs = 0;
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -672,7 +679,14 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
return false;
}
- memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
+ int n_outputs = 0;
+ for (int i = 0; i < n_tokens; ++i) {
+ n_outputs += batch_view.logits[i] != 0;
+ }
+
+ memcpy(batch_logits.data() + prev_outputs*n_vocab, llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float));
+
+ prev_outputs += n_outputs;
}
return true;
@@ -779,7 +793,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
size_t ending_logprob_count[4];
double ending_logprob[4];
- size_t i_batch; // starting index in the llama_batch
+ size_t i_logits; // starting index of logits in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
std::vector<llama_token> seq_tokens[4];
@@ -844,9 +858,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
- llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 4);
std::vector<float> tok_logits(n_vocab);
+ // TODO: this could be made smaller; it's currently the worst-case size
std::vector<float> batch_logits(n_vocab*n_ctx);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
@@ -857,16 +872,17 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+ size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
- // each task has 4 unique seuqnce ids - one for each ending
+ // each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1];
+ int n_logits = 0;
const int s0 = 4*(i1 - i0);
if (s0 + 4 > max_seq) {
@@ -874,18 +890,23 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
- llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
+ llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ n_logits += 1;
for (int s = 0; s < 4; ++s) {
- for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
- llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
+ const size_t seq_tokens_size = hs_cur.seq_tokens[s].size();
+ // TODO: don't evaluate the last token of each sequence
+ for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) {
+ const bool needs_logits = i < seq_tokens_size - 1;
+ llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ n_logits += needs_logits;
}
}
- hs_cur.i_batch = i_batch;
- i_batch += hs_cur.required_tokens;
+ hs_cur.i_logits = i_logits;
+ i_logits += n_logits;
n_cur += hs_data[i1].required_tokens;
if (++i1 == hs_task_count) {
@@ -911,12 +932,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
- size_t li = hs_cur.common_prefix;
+ size_t li = 1; // skip the last logit of the common prefix (computed separately below)
for (int s = 0; s < 4; ++s) {
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
- eval_pairs.emplace_back(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]);
+ eval_pairs.emplace_back(hs_cur.i_logits + li++, hs_cur.seq_tokens[s][j + 1]);
}
- ++li;
}
}
// Then we do the actual calculation
@@ -928,7 +948,8 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
+ // get the logits of the last token of the common prefix
+ std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*hs_cur.i_logits, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
@@ -978,7 +999,7 @@ struct winogrande_entry {
std::array<std::string, 2> choices;
int answer;
- size_t i_batch;
+ size_t i_logits;
size_t common_prefix;
size_t required_tokens;
size_t n_base1; // number of tokens for context + choice 1
@@ -1104,6 +1125,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
task.common_prefix++;
}
+ // TODO: the last token of each of the sequences don't need to be evaluated
task.required_tokens = task.common_prefix +
task.seq_tokens[0].size() - task.common_prefix +
task.seq_tokens[1].size() - task.common_prefix;
@@ -1121,9 +1143,10 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
- llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
+ llama_batch batch = llama_batch_init(n_ctx, 0, 2);
std::vector<float> tok_logits(n_vocab);
+ // TODO: this could be made smaller; it's currently the worst-case size
std::vector<float> batch_logits(n_vocab*n_ctx);
std::vector<std::pair<size_t, llama_token>> eval_pairs;
@@ -1137,29 +1160,33 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0;
+ size_t i_logits = 0;
llama_batch_clear(batch);
while (n_cur + (int) data[i1].required_tokens <= n_ctx) {
+ int n_logits = 0;
const int s0 = 2*(i1 - i0);
if (s0 + 2 > max_seq) {
break;
}
for (size_t i = 0; i < data[i1].common_prefix; ++i) {
- llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1}, false);
+ llama_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false);
}
batch.logits[batch.n_tokens - 1] = true;
+ n_logits += 1;
for (int s = 0; s < 2; ++s) {
+ // TODO: end before the last token, no need to predict past the end of the sequences
for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) {
llama_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true);
+ n_logits += 1;
}
}
- data[i1].i_batch = i_batch;
- i_batch += data[i1].required_tokens;
+ data[i1].i_logits = i_logits;
+ i_logits += n_logits;
n_cur += data[i1].required_tokens;
if (++i1 == data.size()) {
@@ -1190,15 +1217,16 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const auto& n_base1 = skip_choice ? task.n_base1 : task.common_prefix;
const int last_1st = task.seq_tokens[0].size() - n_base1 > 1 ? 1 : 0;
- size_t li = n_base1 - 1;
+ size_t li = n_base1 - task.common_prefix;
for (size_t j = n_base1-1; j < task.seq_tokens[0].size()-1-last_1st; ++j) {
- eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[0][j+1]);
+ eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[0][j+1]);
}
const auto& n_base2 = skip_choice ? task.n_base2 : task.common_prefix;
const int last_2nd = task.seq_tokens[1].size() - n_base2 > 1 ? 1 : 0;
- li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - 1;
+ // FIXME: this uses the wrong first logits when not skipping the choice word
+ li = task.seq_tokens[0].size() - task.common_prefix + n_base2 - task.common_prefix;
for (size_t j = n_base2-1; j < task.seq_tokens[1].size()-1-last_2nd; ++j) {
- eval_pairs.emplace_back(task.i_batch + li++, task.seq_tokens[1][j+1]);
+ eval_pairs.emplace_back(task.i_logits + li++, task.seq_tokens[1][j+1]);
}
}
compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
@@ -1287,7 +1315,7 @@ struct multiple_choice_task {
}
// For evaluation
- size_t i_batch; // starting index in the llama_batch
+ size_t i_logits; // starting index of logits in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all answers
std::vector<std::vector<llama_token>> seq_tokens;
@@ -1366,7 +1394,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
std::vector<uint32_t> task_pos(n_task);
strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t));
if (strstream.fail()) {
- printf("%s: failed to raad task positions from prompt\n", __func__);
+ printf("%s: failed to read task positions from prompt\n", __func__);
return;
}
@@ -1447,7 +1475,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
return;
}
} else {
- int n_dot = n_task/100;
+ int n_dot = std::max((int) n_task/100, 1);
int i_task = 0;
for (auto& task : tasks) {
++i_task;
@@ -1491,17 +1519,18 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
int n_cur = 0;
size_t i1 = i0;
- size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
+ size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
- // each task has 4 unique seuqnce ids - one for each ending
+ // each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
int s0 = 0;
while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) {
auto& cur_task = tasks[i1];
+ int n_logits = 0;
int num_answers = cur_task.seq_tokens.size();
if (s0 + num_answers > max_seq) {
@@ -1518,17 +1547,22 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
+ n_logits += 1;
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
- for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) {
- llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true);
+ const size_t seq_tokens_size = cur_task.seq_tokens[s].size();
+ // TODO: don't evaluate the last token of each sequence
+ for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) {
+ const bool needs_logits = i < seq_tokens_size - 1;
+ llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits);
+ n_logits += needs_logits;
}
}
s0 += num_answers;
- cur_task.i_batch = i_batch;
- i_batch += cur_task.required_tokens;
+ cur_task.i_logits = i_logits;
+ i_logits += n_logits;
n_cur += cur_task.required_tokens;
if (++i1 == tasks.size()) {
@@ -1554,12 +1588,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
eval_pairs.clear();
for (size_t i = i0; i < i1; ++i) {
auto& cur_task = tasks[i];
- size_t li = cur_task.common_prefix;
+ size_t li = 1; // skip the last logit of the common prefix (computed separately below)
for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) {
for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) {
- eval_pairs.emplace_back(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1]);
+ eval_pairs.emplace_back(cur_task.i_logits + li++, cur_task.seq_tokens[s][j + 1]);
}
- ++li;
}
}
// Then we do the actual calculation
@@ -1578,7 +1611,8 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
//}
//printf("\n common_prefix: %zu\n", cur_task.common_prefix);
- std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float));
+ // get the logits of the last token of the common prefix
+ std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*cur_task.i_logits, n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits);
@@ -1730,6 +1764,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
+ // TODO: use llama_batch.logits instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 526de596..53ad9239 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -747,7 +747,8 @@ struct server_context {
{
const int32_t n_batch = llama_n_batch(ctx);
- batch = llama_batch_init(n_batch, 0, params.n_parallel);
+ // only a single seq_id per token is needed
+ batch = llama_batch_init(n_batch, 0, 1);
}
metrics.init();
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 8b31b678..6e0815b3 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -65,7 +65,6 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL;
// load the target model
- params.logits_all = true;
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 48232b6e..be8e33a5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2505,7 +2505,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
diff --git a/ggml-kompute.cpp b/ggml-kompute.cpp
index 81dd5067..407062e6 100644
--- a/ggml-kompute.cpp
+++ b/ggml-kompute.cpp
@@ -1430,6 +1430,10 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
struct ggml_tensor * dst = gf->nodes[i];
GGML_ASSERT(dst->data != nullptr);
+ if (ggml_is_empty(dst)) {
+ continue;
+ }
+
switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
diff --git a/ggml-metal.m b/ggml-metal.m
index cbe22aa3..a08abbc2 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -847,6 +847,10 @@ static enum ggml_status ggml_metal_graph_compute(
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
struct ggml_tensor * dst = gf->nodes[i];
+ if (ggml_is_empty(dst)) {
+ continue;
+ }
+
switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index aa73d67d..b3f8b7ea 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -2234,6 +2234,11 @@ static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(gg
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
+
+ if (ggml_is_empty(node)) {
+ continue;
+ }
+
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
diff --git a/ggml-sycl.cpp b/ggml-sycl.cpp
index fc4d2964..789ba97b 100644
--- a/ggml-sycl.cpp
+++ b/ggml-sycl.cpp
@@ -16973,7 +16973,7 @@ GGML_CALL static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t back
params.ith = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
#ifndef NDEBUG
diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index cbceaa19..521a1314 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -5566,7 +5566,7 @@ GGML_CALL static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backen
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
- if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
+ if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
diff --git a/ggml.c b/ggml.c
index a86b41c1..eb469d0f 100644
--- a/ggml.c
+++ b/ggml.c
@@ -2607,6 +2607,16 @@ static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
+GGML_CALL bool ggml_is_empty(const struct ggml_tensor * tensor) {
+ for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+ if (tensor->ne[i] == 0) {
+ // empty if any dimension has no elements
+ return true;
+ }
+ }
+ return false;
+}
+
bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -2621,7 +2631,7 @@ bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor
static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
- return
+ return ggml_is_empty(t0) ? ggml_is_empty(t1) :
(t1->ne[0]%t0->ne[0] == 0) &&
(t1->ne[1]%t0->ne[1] == 0) &&
(t1->ne[2]%t0->ne[2] == 0) &&
@@ -16114,7 +16124,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
GGML_ASSERT(params);
- if (tensor->op == GGML_OP_NONE) {
+ if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
return;
}
@@ -17983,6 +17993,12 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const
static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_threads) {
int n_tasks = 0;
+ if (ggml_is_empty(node)) {
+ // no need to multi-thread a no-op
+ n_tasks = 1;
+ return n_tasks;
+ }
+
switch (node->op) {
case GGML_OP_CPY:
case GGML_OP_DUP:
diff --git a/ggml.h b/ggml.h
index 425c9b6a..5d4a4ceb 100644
--- a/ggml.h
+++ b/ggml.h
@@ -750,6 +750,7 @@ extern "C" {
GGML_API GGML_CALL bool ggml_is_transposed(const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor);
GGML_API GGML_CALL bool ggml_is_permuted (const struct ggml_tensor * tensor);
+ GGML_API GGML_CALL bool ggml_is_empty (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor);
GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor);
diff --git a/llama.cpp b/llama.cpp
index 68c360c7..22db79d6 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1777,6 +1777,7 @@ struct llama_cparams {
uint32_t n_ctx; // context size used during inference
uint32_t n_batch;
uint32_t n_ubatch;
+ uint32_t n_seq_max;
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing
@@ -2139,20 +2140,20 @@ struct llama_context {
// host buffer for the model output (logits and embeddings)
ggml_backend_buffer_t buf_output = nullptr;
- // decode output (2-dimensional array: [n_tokens][n_vocab])
- size_t logits_size = 0;
- float * logits = nullptr;
+ // decode output (2-dimensional array: [n_outputs][n_vocab])
+ size_t logits_size = 0; // capacity (of floats) for logits
+ float * logits = nullptr;
+
+ std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
+ size_t output_size = 0; // capacity (of tokens positions) for the output buffers
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
-#ifndef NDEBUG
- // guard against access to unset logits
- std::vector<bool> logits_valid;
-#endif
bool logits_all = false;
- // embeddings output (2-dimensional array: [n_tokens][n_embd])
+ // embeddings output (2-dimensional array: [n_outputs][n_embd])
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
- size_t embd_size = 0;
- float * embd = nullptr;
+ size_t embd_size = 0; // capacity (of floats) for embeddings
+ float * embd = nullptr;
// sequence embeddings output (map of [n_embd] vectors)
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
@@ -2169,14 +2170,15 @@ struct llama_context {
struct ggml_tensor * inp_tokens; // I32 [n_batch]
struct ggml_tensor * inp_embd; // F32 [n_embd, n_batch]
struct ggml_tensor * inp_pos; // I32 [n_batch]
+ struct ggml_tensor * inp_out_ids; // I32 [n_outputs]
struct ggml_tensor * inp_KQ_mask; // F32 [kv_size, n_batch]
- struct ggml_tensor * inp_KQ_pos; // F32 [kv_size]
+ struct ggml_tensor * inp_KQ_pos; // F32 [n_kv]
struct ggml_tensor * inp_K_shift; // I32 [kv_size]
struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
struct ggml_tensor * inp_cls; // I32 [n_batch]
struct ggml_tensor * inp_s_copy; // I32 [kv_size]
- struct ggml_tensor * inp_s_mask; // F32 [1, kv_size]
- struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]
+ struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
+ struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
// control vectors
struct llama_control_vector cvec;
@@ -5846,7 +5848,8 @@ struct llm_build_context {
const float norm_rms_eps;
const int32_t n_tokens;
- const int32_t n_kv; // size of KV cache to consider (n_kv <= n_ctx)
+ const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
+ const int32_t n_outputs;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
@@ -5893,6 +5896,7 @@ struct llm_build_context {
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (batch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
+ n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
pooling_type (cparams.pooling_type),
@@ -5914,6 +5918,7 @@ struct llm_build_context {
lctx.inp_tokens = nullptr;
lctx.inp_embd = nullptr;
lctx.inp_pos = nullptr;
+ lctx.inp_out_ids = nullptr;
lctx.inp_KQ_mask = nullptr;
lctx.inp_KQ_pos = nullptr;
lctx.inp_K_shift = nullptr;
@@ -6037,6 +6042,13 @@ struct llm_build_context {
return lctx.inp_pos;
}
+ struct ggml_tensor * build_inp_out_ids() {
+ lctx.inp_out_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_outputs);
+ cb(lctx.inp_out_ids, "inp_out_ids", -1);
+ ggml_set_input(lctx.inp_out_ids);
+ return lctx.inp_out_ids;
+ }
+
struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
if (causal) {
lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens);
@@ -6093,6 +6105,9 @@ struct llm_build_context {
struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6160,6 +6175,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -6339,6 +6362,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -6454,6 +6484,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = cur;
// feed forward
@@ -6497,6 +6535,9 @@ struct llm_build_context {
struct ggml_cgraph * build_grok() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+ // mutable variable, needed during the last layer of the computation to skip unused tokens
+ int32_t n_tokens = this->n_tokens;
+
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6568,6 +6609,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ n_tokens = n_outputs;
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// Grok
// if attn_out_norm is present then apply it before adding the input
if (model.layers[il].attn_out_norm) {
@@ -6745,6 +6794,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -6942,6 +6998,13 @@ struct llm_build_context {
Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ residual = ggml_get_rows(ctx0, residual, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
cb(ffn_inp, "ffn_inp", il);
@@ -7031,6 +7094,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7188,6 +7258,13 @@ struct llm_build_context {
}
cb(cur, "kqv_out", il);
+ if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// re-add the layer input
cur = ggml_add(ctx0, cur, inpL);
@@ -7310,6 +7387,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -7408,6 +7492,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -7521,6 +7612,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7627,6 +7725,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7739,6 +7844,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -7857,6 +7969,14 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);
+ }
+
// FF
{
ffn_output = llm_build_ffn(ctx0, attn_norm_output,
@@ -7954,6 +8074,14 @@ struct llm_build_context {
cur = attention_norm;
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ sa_out = ggml_get_rows(ctx0, sa_out, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// feed-forward network
{
cur = llm_build_ffn(ctx0, cur,
@@ -8046,6 +8174,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -8146,6 +8281,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
cb(ffn_inp, "ffn_inp", il);
@@ -8255,6 +8397,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8365,6 +8514,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8488,6 +8644,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
@@ -8602,6 +8765,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
cb(sa_out, "sa_out", il);
@@ -8714,6 +8884,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
@@ -8861,6 +9038,15 @@ struct llm_build_context {
struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0);
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ x = ggml_get_rows(ctx0, x, inp_out_ids);
+ y = ggml_get_rows(ctx0, y, inp_out_ids);
+ z = ggml_get_rows(ctx0, z, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
// {d_inner, n_tokens} * {d_inner} => {d_inner, n_tokens}
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
@@ -8963,6 +9149,13 @@ struct llm_build_context {
Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
struct ggml_tensor * attn_out = cur;
// feed-forward network
@@ -9260,9 +9453,39 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}
+ if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
+ const int64_t n_tokens = batch.n_tokens;
+
+ GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
+ int32_t * data = (int32_t *) lctx.inp_out_ids->data;
+
+ if (lctx.n_outputs == n_tokens) {
+ for (int i = 0; i < n_tokens; ++i) {
+ data[i] = i;
+ }
+ } else if (batch.logits) {
+ int32_t n_outputs = 0;
+ for (int i = 0; i < n_tokens; ++i) {
+ if (batch.logits[i]) {
+ data[n_outputs++] = i;
+ }
+ }
+ // the graph needs to have been passed the correct number of outputs
+ GGML_ASSERT(lctx.n_outputs == n_outputs);
+ } else if (lctx.n_outputs == 1) {
+ // only keep last output
+ data[0] = n_tokens - 1;
+ } else {
+ GGML_ASSERT(lctx.n_outputs == 0);
+ }
+ }
+
GGML_ASSERT(
+ // (!a || b) is a logical implication (a -> b)
+ // !hparams.causal_attn -> !cparams.causal_attn
(hparams.causal_attn || !cparams.causal_attn) &&
- "non-causal attention with generative models is not supported"
+ "causal attention with embedding models is not supported"
);
if (lctx.inp_KQ_mask) {
@@ -9441,6 +9664,74 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
+// Make sure enough space is available for outputs.
+// Returns max number of outputs for which space was reserved.
+static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
+ const auto & cparams = lctx.cparams;
+ const auto & hparams = lctx.model.hparams;
+
+ const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
+
+ const auto n_batch = cparams.n_batch;
+ const auto n_vocab = hparams.n_vocab;
+ const auto n_embd = hparams.n_embd;
+
+ // TODO: use a per-batch flag for logits presence instead
+ const bool has_logits = cparams.causal_attn;
+ const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE);
+
+ const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
+ const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
+
+ if (lctx.output_ids.empty()) {
+ // init, never resized afterwards
+ lctx.output_ids.resize(n_batch);
+ }
+
+ const size_t prev_size = lctx.buf_output ? ggml_backend_buffer_get_size(lctx.buf_output) : 0;
+ const size_t new_size = (logits_size + embd_size) * sizeof(float);
+
+ // alloc only when more than the current capacity is required
+ // TODO: also consider shrinking the buffer
+ if (!lctx.buf_output || prev_size < new_size) {
+ if (lctx.buf_output) {
+#ifndef NDEBUG
+ // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark)
+ LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+#endif
+ ggml_backend_buffer_free(lctx.buf_output);
+ lctx.buf_output = nullptr;
+ lctx.logits = nullptr;
+ lctx.embd = nullptr;
+ }
+
+ lctx.buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), new_size);
+ if (lctx.buf_output == nullptr) {
+ LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0));
+ return 0;
+ }
+ }
+
+ float * output_base = (float *) ggml_backend_buffer_get_base(lctx.buf_output);
+
+ lctx.logits = has_logits ? output_base : nullptr;
+ lctx.embd = has_embd ? output_base + logits_size : nullptr;
+
+ lctx.output_size = n_outputs_max;
+ lctx.logits_size = logits_size;
+ lctx.embd_size = embd_size;
+
+ // set all ids as invalid (negative)
+ std::fill(lctx.output_ids.begin(), lctx.output_ids.end(), -1);
+
+ ggml_backend_buffer_clear(lctx.buf_output, 0);
+
+ lctx.n_outputs = 0;
+
+ return n_outputs_max;
+}
+
+
static void llama_graph_compute(
llama_context & lctx,
ggml_cgraph * gf,
@@ -9516,16 +9807,8 @@ static int llama_decode_internal(
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
-
- auto * logits_out = lctx.logits;
-
-#ifndef NDEBUG
- auto & logits_valid = lctx.logits_valid;
- logits_valid.clear();
- logits_valid.resize(n_tokens_all);
-
- memset(logits_out, 0, lctx.logits_size*sizeof(float));
-#endif
+ uint32_t n_outputs = 0;
+ uint32_t n_outputs_prev = 0;
const auto n_ubatch = cparams.n_ubatch;
@@ -9534,6 +9817,38 @@ static int llama_decode_internal(
std::vector<llama_seq_id *> seq_id_arr;
std::vector<std::vector<llama_seq_id>> seq_id;
+ // count outputs
+ if (batch_all.logits) {
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ n_outputs += batch_all.logits[i] != 0;
+ }
+ } else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) {
+ n_outputs = n_tokens_all;
+ } else {
+ // keep last output only
+ n_outputs = 1;
+ }
+
+ // reserve output buffer
+ if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
+ LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
+ return -2;
+ };
+
+ // set output mappings
+ if (batch_all.logits) {
+ int32_t i_logits = 0;
+ for (uint32_t i = 0; i < n_tokens_all; ++i) {
+ if (batch_all.logits[i]) {
+ lctx.output_ids[i] = i_logits++;
+ }
+ }
+ } else {
+ for (uint32_t i = 0; i < n_outputs; ++i) {
+ lctx.output_ids[i] = i;
+ }
+ }
+
for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
llama_batch u_batch = {
@@ -9549,6 +9864,27 @@ static int llama_decode_internal(
/* .all_seq_id = */ batch_all.all_seq_id,
};
+ // count the outputs in this u_batch
+ {
+ int32_t n_outputs_new = 0;
+
+ if (u_batch.logits) {
+ for (uint32_t i = 0; i < n_tokens; i++) {
+ n_outputs_new += u_batch.logits[i] != 0;
+ }
+ } else if (n_outputs == n_tokens_all) {
+ n_outputs_new = n_tokens;
+ } else {
+ // keep last output only
+ if (cur_token + n_tokens >= n_tokens_all) {
+ n_outputs_new = 1;
+ }
+ }
+
+ // needs to happen before the graph is built
+ lctx.n_outputs = n_outputs_new;
+ }
+
int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
GGML_ASSERT(n_threads > 0);
@@ -9612,23 +9948,37 @@ static int llama_decode_internal(
struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
- if (!hparams.causal_attn) {
+ if (lctx.n_outputs == 0) {
+ // no output
+ res = nullptr;
+ embd = nullptr;
+ } else if (!hparams.causal_attn) {
res = nullptr; // do not extract logits for embedding models such as BERT
// token or sequence embeddings
embd = gf->nodes[gf->n_nodes - 1];
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0);
- } else {
- if (strcmp(res->name, "result_output") == 0) {
- // the token embeddings could be the second to last tensor, or the third to last tensor
- if (strcmp(embd->name, "result_norm") != 0) {
- embd = gf->nodes[gf->n_nodes - 3];
- GGML_ASSERT(strcmp(embd->name, "result_norm") == 0);
- }
- } else {
- GGML_ASSERT(false && "missing result_output tensor");
+ } else if (cparams.embeddings) {
+ // the embeddings could be in the second to last tensor, or any of the previous tensors
+ int i_embd = gf->n_nodes - 2;
+ for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) {
+ i_embd = gf->n_nodes - i;
+ if (i_embd < 0) { break; }
+ embd = gf->nodes[i_embd];
+ }
+ GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor");
+
+ // TODO: use a per-batch flag to know when to skip logits while keeping embeddings
+ if (!cparams.causal_attn) {
+ res = nullptr; // do not extract logits when not needed
+ // skip computing logits
+ // TODO: is this safe?
+ gf->n_nodes = i_embd + 1;
}
+ } else {
+ embd = nullptr; // do not extract embeddings when not needed
+ GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
}
// LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
@@ -9671,50 +10021,23 @@ static int llama_decode_internal(
//}
// extract logits
- // TODO: do not compute and extract logits if only embeddings are needed
- // update the graphs to skip "result_output" if logits are not needed
if (res) {
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res);
GGML_ASSERT(backend_res != nullptr);
- if (u_batch.logits) {
- int32_t i_first = -1;
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] && i_first == -1) {
- i_first = (int32_t) i;
- }
- if (u_batch.logits[i] == 0 || i == n_tokens - 1) {
- if (i_first != -1) {
- int i_last = u_batch.logits[i] == 0 ? i : i + 1;
- // extract logits for the range [i_first, i_last)
- // group the requests to minimize the number of calls to the backend
- ggml_backend_tensor_get_async(backend_res, res,
- logits_out + n_vocab*(cur_token + i_first),
- i_first*n_vocab*sizeof(float),
- (i_last - i_first)*n_vocab*sizeof(float));
- i_first = -1;
- }
- }
-#ifndef NDEBUG
- logits_valid[cur_token + i] = u_batch.logits[i] != 0;;
-#endif
- }
- } else if (lctx.logits_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out + n_vocab*cur_token, 0, n_vocab*n_tokens*sizeof(float));
-#ifndef NDEBUG
- std::fill(logits_valid.begin() + cur_token, logits_valid.begin() + cur_token + n_tokens, true);
-#endif
- } else {
- if (cur_token + n_tokens >= n_tokens_all) {
- ggml_backend_tensor_get_async(backend_res, res, logits_out, n_vocab*(n_tokens - 1)*sizeof(float), n_vocab*sizeof(float));
-#ifndef NDEBUG
- logits_valid[0] = true;
-#endif
- }
+ GGML_ASSERT(lctx.logits != nullptr);
+
+ float * logits_out = lctx.logits + n_outputs_prev*n_vocab;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_vocab <= (int64_t) lctx.logits_size);
+ ggml_backend_tensor_get_async(backend_res, res, logits_out, 0, n_outputs_new*n_vocab*sizeof(float));
}
}
// extract embeddings
- if (cparams.embeddings && embd) {
+ if (embd) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);
@@ -9722,16 +10045,14 @@ static int llama_decode_internal(
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
- auto & embd_out = lctx.embd;
-
- if (u_batch.logits) {
- //embd_out.resize(n_embd * n_tokens);
- for (uint32_t i = 0; i < n_tokens; i++) {
- if (u_batch.logits[i] == 0) {
- continue;
- }
- ggml_backend_tensor_get_async(backend_embd, embd, embd_out + n_embd*(i + cur_token), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
- }
+ GGML_ASSERT(lctx.embd != nullptr);
+ float * embd_out = lctx.embd + n_outputs_prev*n_embd;
+ const int32_t n_outputs_new = lctx.n_outputs;
+
+ if (n_outputs_new) {
+ GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+ GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+ ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_CLS:
@@ -9758,6 +10079,7 @@ static int llama_decode_internal(
} break;
}
}
+ n_outputs_prev += lctx.n_outputs;
}
// wait for the computation to finish (automatically done when obtaining the model output)
@@ -13531,7 +13853,7 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
- // TODO: maybe add n_seq_max here too
+ cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -13733,25 +14055,12 @@ struct llama_context * llama_new_context_with_model(
// graph outputs buffer
{
- // resized during inference, reserve maximum
- ctx->logits_size = hparams.n_vocab*cparams.n_batch;
- ctx->embd_size = params.embeddings ? hparams.n_embd*cparams.n_batch : 0;
-
- const size_t buf_output_size = (ctx->logits_size + ctx->embd_size)*sizeof(float);
-
- ctx->buf_output = ggml_backend_buft_alloc_buffer(llama_default_buffer_type_cpu(true), buf_output_size);
- if (ctx->buf_output == nullptr) {
- LLAMA_LOG_ERROR("%s: failed to allocate logits buffer\n", __func__);
+ // resized during inference when a batch uses more outputs
+ if (llama_output_reserve(*ctx, params.n_seq_max) < params.n_seq_max) {
+ LLAMA_LOG_ERROR("%s: failed to reserve initial output buffer\n", __func__);
llama_free(ctx);
return nullptr;
}
- ggml_backend_buffer_clear(ctx->buf_output, 0);
-
-
- ctx->logits = (float *) ggml_backend_buffer_get_base(ctx->buf_output);
- if (params.embeddings) {
- ctx->embd = ctx->logits + ctx->logits_size;
- }
LLAMA_LOG_INFO("%s: %10s output buffer size = %8.2f MiB\n", __func__,
ggml_backend_buffer_name(ctx->buf_output),
@@ -14268,27 +14577,33 @@ void llama_kv_cache_update(struct llama_context * ctx) {
// Returns the *maximum* size of the state
size_t llama_get_state_size(const struct llama_context * ctx) {
+ const auto & cparams = ctx->cparams;
+ const auto & hparams = ctx->model.hparams;
+
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = LLAMA_MAX_RNG_STATE;
+ const size_t s_n_outputs = sizeof(size_t);
+ // assume worst case for outputs although only currently set ones are serialized
+ const size_t s_output_pos = ctx->cparams.n_batch * sizeof(int32_t);
const size_t s_logits_size = sizeof(size_t);
- // assume worst case for logits although only currently set ones are serialized
- const size_t s_logits = ctx->logits_size * sizeof(float);
+ const size_t s_logits = ctx->logits_size ? cparams.n_batch * hparams.n_vocab * sizeof(float) : 0;
const size_t s_embedding_size = sizeof(size_t);
- const size_t s_embedding = ctx->embd_size * sizeof(float);
+ const size_t s_embedding = ctx->embd_size ? cparams.n_batch * hparams.n_embd * sizeof(float) : 0;
const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size();
- // TODO: assume the max is more than 1 seq_id per KV cell
- const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
+ const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
const size_t s_total = (
+ s_rng_size
+ s_rng
+ + s_n_outputs
+ + s_output_pos
+ s_logits_size
+ s_logits
+ s_embedding_size
@@ -14363,7 +14678,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
std::ostringstream rng_ss;
rng_ss << ctx->rng;
- const std::string & rng_str = rng_ss.str();
+ const std::string & rng_str = rng_ss.str();
const size_t rng_size = rng_str.size();
GGML_ASSERT(rng_size <= LLAMA_MAX_RNG_STATE);
@@ -14372,25 +14687,61 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(rng_str.data(), rng_size);
}
- // copy logits
+ // copy outputs
{
- const size_t logits_size = ctx->logits_size;
+ // Can't use ctx->n_outputs because it's not for the
+ // entire last batch when n_ubatch is smaller than n_batch
+ size_t n_outputs = 0;
- data_ctx->write(&logits_size, sizeof(logits_size));
+ // copy output ids
+ {
+ std::vector<int32_t> output_pos;
- if (logits_size) {
- data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ const size_t n_batch = ctx->cparams.n_batch;
+ const auto & output_ids = ctx->output_ids;
+
+ output_pos.resize(ctx->output_size);
+
+ // build a more compact representation of the output ids
+ for (size_t i = 0; i < n_batch; ++i) {
+ // map an output id to a position in the batch
+ int32_t pos = output_ids[i];
+ if (pos >= 0) {
+ if ((size_t) pos >= n_outputs) {
+ n_outputs = pos + 1;
+ }
+ GGML_ASSERT((size_t) pos < ctx->output_size);
+ output_pos[pos] = i;
+ }
+ }
+
+ data_ctx->write(&n_outputs, sizeof(n_outputs));
+
+ if (n_outputs) {
+ data_ctx->write(output_pos.data(), n_outputs * sizeof(int32_t));
+ }
}
- }
- // copy embeddings
- {
- const size_t embeddings_size = ctx->embd_size;
+ // copy logits
+ {
+ const size_t logits_size = std::min(ctx->logits_size, n_outputs * ctx->model.hparams.n_vocab);
- data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+ data_ctx->write(&logits_size, sizeof(logits_size));
- if (embeddings_size) {
- data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ if (logits_size) {
+ data_ctx->write(ctx->logits, logits_size * sizeof(float));
+ }
+ }
+
+ // copy embeddings
+ {
+ const size_t embeddings_size = std::min(ctx->embd_size, n_outputs * ctx->model.hparams.n_embd);
+
+ data_ctx->write(&embeddings_size, sizeof(embeddings_size));
+
+ if (embeddings_size) {
+ data_ctx->write(ctx->embd, embeddings_size * sizeof(float));
+ }
}
}
@@ -14403,9 +14754,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
- const size_t kv_buf_size = kv_self.total_size();
+ // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
const uint32_t kv_size = kv_self.size;
+ const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
@@ -14414,6 +14766,8 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) {
+ const size_t pre_kv_buf_size = data_ctx->get_size_written();
+
std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
@@ -14443,6 +14797,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
data_ctx->write(tmp_buf.data(), tmp_buf.size());
}
}
+ GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
}
for (uint32_t i = 0; i < kv_head; ++i) {
@@ -14487,6 +14842,28 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(!rng_ss.fail());
}
+ // set output ids
+ {
+ size_t n_outputs;
+ std::vector<int32_t> output_pos;
+
+ memcpy(&n_outputs, inp, sizeof(n_outputs)); inp += sizeof(n_outputs);
+
+ GGML_ASSERT(n_outputs <= llama_output_reserve(*ctx, n_outputs));
+
+ if (n_outputs) {
+ output_pos.resize(n_outputs);
+ memcpy(output_pos.data(), inp, n_outputs * sizeof(int32_t));
+ inp += n_outputs * sizeof(int32_t);
+
+ for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
+ int32_t id = output_pos[i];
+ GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+ ctx->output_ids[id] = i;
+ }
+ }
+ }
+
// set logits
{
size_t logits_size;
@@ -14507,7 +14884,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
memcpy(&embeddings_size, inp, sizeof(embeddings_size)); inp += sizeof(embeddings_size);
- GGML_ASSERT(ctx->embd_size == embeddings_size);
+ GGML_ASSERT(ctx->embd_size >= embeddings_size);
if (embeddings_size) {
memcpy(ctx->embd, inp, embeddings_size * sizeof(float));
@@ -14534,8 +14911,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
+ if (kv_self.size != kv_size) {
+ // the KV cache needs to be big enough to load all the KV cells from the saved state
+ GGML_ASSERT(kv_self.size >= kv_head);
+
+ LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
+ __func__, kv_head, kv_size, kv_self.size);
+ }
+
if (kv_buf_size) {
- GGML_ASSERT(kv_self.total_size() == kv_buf_size);
+ const size_t pre_kv_buf_size = inp - src;
+
+ GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
for (int il = 0; il < (int) n_layer; ++il) {
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
@@ -14555,23 +14942,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
// v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
- const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+ const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
inp += v_row_size;
}
}
+ GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
- GGML_ASSERT(kv_self.size == kv_size);
+ llama_kv_cache_clear(ctx);
ctx->kv_self.head = kv_head;
- ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;
- ctx->kv_self.cells.resize(kv_size);
-
for (uint32_t i = 0; i < kv_head; ++i) {
llama_pos pos;
size_t seq_id_size;
@@ -14588,11 +14973,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ctx->kv_self.cells[i].seq_id.insert(seq_id);
}
}
-
- for (uint32_t i = kv_head; i < kv_size; ++i) {
- ctx->kv_self.cells[i].pos = -1;
- ctx->kv_self.cells[i].seq_id.clear();
- }
}
const size_t nread = inp - src;
@@ -14798,11 +15178,33 @@ float * llama_get_logits(struct llama_context * ctx) {
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
- assert(ctx->logits_valid.at(i));
-
llama_synchronize(ctx);
- return ctx->logits + i*ctx->model.hparams.n_vocab;
+ try {
+ if (ctx->logits == nullptr) {
+ throw std::runtime_error("no logits");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->logits + j*ctx->model.hparams.n_vocab;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings(struct llama_context * ctx) {
@@ -14814,7 +15216,31 @@ float * llama_get_embeddings(struct llama_context * ctx) {
float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
llama_synchronize(ctx);
- return ctx->embd + i*ctx->model.hparams.n_embd;
+ try {
+ if (ctx->embd == nullptr) {
+ throw std::runtime_error("no embeddings");
+ }
+ if ((size_t) i >= ctx->output_ids.size()) {
+ throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ }
+ const int32_t j = ctx->output_ids[i];
+
+ if (j < 0) {
+ throw std::runtime_error(format("batch.logits[%d] != true", i));
+ }
+ if ((size_t) j >= ctx->output_size) {
+ // This should not happen
+ throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+ }
+
+ return ctx->embd + j*ctx->model.hparams.n_embd;
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what());
+#ifndef NDEBUG
+ GGML_ASSERT(false);
+#endif
+ return nullptr;
+ }
}
float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {
diff --git a/llama.h b/llama.h
index 54d62240..1fe4af49 100644
--- a/llama.h
+++ b/llama.h
@@ -39,7 +39,7 @@
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 4
+#define LLAMA_SESSION_VERSION 5
#ifdef __cplusplus
extern "C" {
@@ -678,23 +678,29 @@ extern "C" {
LLAMA_API void llama_synchronize(struct llama_context * ctx);
// Token logits obtained from the last call to llama_decode()
- // The logits for the last token are stored in the last row
- // Logits for which llama_batch.logits[i] == 0 are undefined
- // Rows: n_tokens provided with llama_batch
+ // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+ // in the order they have appeared in the batch.
+ // Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. Equivalent to:
- // llama_get_logits(ctx) + i*n_vocab
+ // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+ // returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
- // Get all output token embeddings
- // shape: [n_tokens*n_embd] (1-dimensional)
+ // Get all output token embeddings.
+ // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
+ // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+ // in the order they have appeared in the batch.
+ // shape: [n_outputs*n_embd]
+ // Otherwise, returns NULL.
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
- // Get the embeddings for the ith token
- // llama_get_embeddings(ctx) + i*n_embd
+ // Get the embeddings for the ith token. Equivalent to:
+ // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
// shape: [n_embd] (1-dimensional)
+ // returns NULL for invalid ids.
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for a sequence id