summaryrefslogtreecommitdiff
path: root/llama.h
diff options
context:
space:
mode:
Diffstat (limited to 'llama.h')
-rw-r--r--llama.h24
1 files changed, 15 insertions, 9 deletions
diff --git a/llama.h b/llama.h
index 54d62240..1fe4af49 100644
--- a/llama.h
+++ b/llama.h
@@ -39,7 +39,7 @@
#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 4
+#define LLAMA_SESSION_VERSION 5
#ifdef __cplusplus
extern "C" {
@@ -678,23 +678,29 @@ extern "C" {
LLAMA_API void llama_synchronize(struct llama_context * ctx);
// Token logits obtained from the last call to llama_decode()
- // The logits for the last token are stored in the last row
- // Logits for which llama_batch.logits[i] == 0 are undefined
- // Rows: n_tokens provided with llama_batch
+ // The logits for which llama_batch.logits[i] != 0 are stored contiguously
+ // in the order they have appeared in the batch.
+ // Rows: number of tokens for which llama_batch.logits[i] != 0
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);
// Logits for the ith token. Equivalent to:
- // llama_get_logits(ctx) + i*n_vocab
+ // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+ // returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);
- // Get all output token embeddings
- // shape: [n_tokens*n_embd] (1-dimensional)
+ // Get all output token embeddings.
+ // when pooling_type == LLAMA_POOLING_TYPE_NONE or when using a generative model,
+ // the embeddings for which llama_batch.logits[i] != 0 are stored contiguously
+ // in the order they have appeared in the batch.
+ // shape: [n_outputs*n_embd]
+ // Otherwise, returns NULL.
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
- // Get the embeddings for the ith token
- // llama_get_embeddings(ctx) + i*n_embd
+ // Get the embeddings for the ith token. Equivalent to:
+ // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
// shape: [n_embd] (1-dimensional)
+ // returns NULL for invalid ids.
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
// Get the embeddings for a sequence id