summary refs log tree commit diff
diff options
context:
space:
mode:
author Jared Van Bortel <jared@nomic.ai> 2023-12-15 22:16:15 -0500
committer GitHub <noreply@github.com> 2023-12-15 22:16:15 -0500
commit 8a5be3bd5885d79ad84aadf32bb8c1a67bd43c19 (patch)
tree c10a5cdb0796969cac2d9d29c915c558ad7d96cc
parent 88ae8952b65cbf32eb1f5703681ea592e510e570 (diff)
llama : sanity checks for access to logits (#4274)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
-rw-r--r-- llama.cpp 22
1 files changed, 22 insertions, 0 deletions
diff --git a/llama.cpp b/llama.cpp
index eddb7085..58fe7492 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1505,6 +1505,10 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
+#ifndef NDEBUG
+ // guard against access to unset logits
+ std::vector<bool> logits_valid;
+#endif
bool logits_all = false;
// input embedding (1-dimensional array: [n_embd])
@@ -6150,6 +6154,14 @@ static int llama_decode_internal(
{
auto & logits_out = lctx.logits;
+#ifndef NDEBUG
+ auto & logits_valid = lctx.logits_valid;
+ logits_valid.clear();
+ logits_valid.resize(n_tokens);
+
+ logits_out.clear();
+#endif
+
if (batch.logits) {
logits_out.resize(n_vocab * n_tokens);
for (uint32_t i = 0; i < n_tokens; i++) {
@@ -6157,13 +6169,22 @@ static int llama_decode_internal(
continue;
}
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
+#ifndef NDEBUG
+ logits_valid[i] = true;
+#endif
}
} else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
+#ifndef NDEBUG
+ std::fill(logits_valid.begin(), logits_valid.end(), true);
+#endif
} else {
logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
+#ifndef NDEBUG
+ logits_valid[0] = true; // only the last token's logits are kept, and they live in row 0 of logits_out
+#endif
}
}
@@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
}
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+ assert(ctx->logits_valid.at(i));
return ctx->logits.data() + i*ctx->model.hparams.n_vocab;
}