Diffstat (limited to 'llama.cpp')
-rw-r--r--	llama.cpp	22
1 file changed, 17 insertions, 5 deletions
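The hunk below replaces the old per-token copy (one ggml_backend_tensor_get_async call for every token whose logits were requested) with one call per contiguous run of requested tokens. What follows is a minimal standalone sketch of that grouping logic, not llama.cpp code: fetch_range and the buffer names are made up for illustration, and a plain memcpy stands in for the backend call.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// copy the logits of tokens [i_first, i_last) in a single request;
// in the diff below this role is played by ggml_backend_tensor_get_async
static void fetch_range(const float * src, float * dst, int n_vocab, int i_first, int i_last) {
    std::memcpy(dst + (size_t) n_vocab*i_first,
                src + (size_t) n_vocab*i_first,
                (size_t) (i_last - i_first)*n_vocab*sizeof(float));
}

int main() {
    const int n_vocab  = 4;
    const int n_tokens = 6;

    // stand-in for the backend tensor holding the logits of all tokens
    std::vector<float> backend_logits(n_vocab*n_tokens);
    for (size_t i = 0; i < backend_logits.size(); i++) backend_logits[i] = (float) i;

    // per-token flags, as in batch.logits: extract logits for tokens 1, 2 and 5
    std::vector<uint8_t> logits = { 0, 1, 1, 0, 0, 1 };

    std::vector<float> logits_out(n_vocab*n_tokens, 0.0f);

    // same loop shape as the diff: open a range at the first flagged token,
    // close it at the first gap or at the end of the batch
    int i_first = -1;
    for (int i = 0; i < n_tokens; i++) {
        if (logits[i] && i_first == -1) {
            i_first = i;
        }
        if (logits[i] == 0 || i == n_tokens - 1) {
            if (i_first != -1) {
                const int i_last = logits[i] == 0 ? i : i + 1;
                fetch_range(backend_logits.data(), logits_out.data(), n_vocab, i_first, i_last);
                std::printf("fetched range [%d, %d)\n", i_first, i_last);
                i_first = -1;
            }
        }
    }

    return 0;
}

With the flags { 0, 1, 1, 0, 0, 1 } the loop issues two fetches, [1, 3) and [5, 6), instead of three single-token copies.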
diff --git a/llama.cpp b/llama.cpp
index c58a029f..b19616e8 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -8925,17 +8925,29 @@ static int llama_decode_internal(
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
+            int32_t i_first = -1;
             for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
+                if (batch.logits[i] && i_first == -1) {
+                    i_first = (int32_t) i;
+                }
+                if (batch.logits[i] == 0 || i == n_tokens - 1) {
+                    if (i_first != -1) {
+                        int i_last = batch.logits[i] == 0 ? i : i + 1;
+                        // extract logits for the range [i_first, i_last)
+                        // group the requests to minimize the number of calls to the backend
+                        ggml_backend_tensor_get_async(backend_res, res,
+                                logits_out.data() + (n_vocab*i_first),
+                                (n_vocab*i_first)*sizeof(float),
+                                (i_last - i_first)*n_vocab*sizeof(float));
+                        i_first = -1;
+                    }
                 }
-                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
-                logits_valid[i] = true;
+                logits_valid[i] = batch.logits[i] != 0;
 #endif
             }
         } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab * n_tokens);
+            logits_out.resize(n_vocab*n_tokens);
             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);