diff options
Diffstat (limited to 'llama.h')
-rw-r--r-- | llama.h | 18 |
1 files changed, 12 insertions, 6 deletions
@@ -163,7 +163,7 @@ extern "C" { // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // - seq_id : the sequence to which the respective token belongs - // - logits : if zero, the logits for the respective token will not be output + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // typedef struct llama_batch { int32_t n_tokens; @@ -173,7 +173,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * logits; + int8_t * logits; // TODO: rename this to "output" // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below @@ -260,7 +260,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - bool embedding; // embedding mode only + bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU // Abort callback @@ -655,14 +655,20 @@ extern "C" { // llama_get_logits(ctx) + i*n_vocab LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); - // Get the embeddings for the input - // shape: [n_embd] (1-dimensional) + // Get all output token embeddings + // shape: [n_tokens*n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Get the embeddings for the ith sequence + // Get the embeddings for the ith token // llama_get_embeddings(ctx) + i*n_embd + // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); + // Get the embeddings for a sequence id + // Returns NULL if pooling_type is LLAMA_POOLING_TYPE_NONE + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // // Vocab // |