Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h | 18 ++++++++++++------
1 file changed, 12 insertions(+), 6 deletions(-)
diff --git a/llama.h b/llama.h
index 70da4cb3..3dc162b0 100644
--- a/llama.h
+++ b/llama.h
@@ -163,7 +163,7 @@ extern "C" {
     //  - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     //  - pos    : the positions of the respective token in the sequence
     //  - seq_id : the sequence to which the respective token belongs
-    //  - logits : if zero, the logits for the respective token will not be output
+    //  - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -173,7 +173,7 @@ extern "C" {
         llama_pos    *  pos;
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
-        int8_t       *  logits;
+        int8_t       *  logits; // TODO: rename this to "output"

         // NOTE: helpers for smooth API transition - can be deprecated in the future
         // for future-proof code, use the above fields instead and ignore everything below
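
The per-token logits flags above are the forward-looking way to select outputs. A minimal sketch of how a caller might use them, assuming a batch created with llama_batch_init() that already holds batch.n_tokens tokens (the variable names are illustrative):

    // Request output only for the last token in the batch; the other
    // tokens are still decoded (and enter the KV cache) but produce
    // neither logits nor embeddings.
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        batch.logits[i] = 0;
    }
    batch.logits[batch.n_tokens - 1] = 1;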
@@ -260,7 +260,7 @@ extern "C" {

         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-        bool embedding;   // embedding mode only
+        bool embeddings;  // if true, extract embeddings (together with logits)
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU

         // Abort callback
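
The renamed embeddings flag is set on the context parameters before the context is created. A short sketch, under the assumption that a model was already loaded with llama_load_model_from_file():

    // Create a context that extracts embeddings together with logits.
    struct llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;
    struct llama_context * ctx = llama_new_context_with_model(model, cparams);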
@@ -655,14 +655,20 @@ extern "C" {
     // llama_get_logits(ctx) + i*n_vocab
     LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

-    // Get the embeddings for the input
-    // shape: [n_embd] (1-dimensional)
+    // Get all output token embeddings
+    // shape: [n_tokens*n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

-    // Get the embeddings for the ith sequence
+    // Get the embeddings for the ith token
     // llama_get_embeddings(ctx) + i*n_embd
+    // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);

+    // Get the embeddings for a sequence id
+    // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);
+
     //
     // Vocab
     //
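
Taken together, the getters cover both pooled and per-token retrieval. A hedged usage sketch, run after a successful llama_decode() call; model, ctx and n_tokens are assumed from earlier setup:

    const int n_embd = llama_n_embd(model);

    // Pooled embedding for sequence 0; NULL when the context uses
    // LLAMA_POOLING_TYPE_NONE.
    const float * emb = llama_get_embeddings_seq(ctx, 0);
    if (emb == NULL) {
        // No pooling: fall back to the embedding of the last token,
        // assuming it was marked for output in the batch.
        emb = llama_get_embeddings_ith(ctx, n_tokens - 1);
    }
    // emb now points at n_embd floats.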