author    Douglas Hanley <thesecretaryofwar@gmail.com>  2024-02-13 06:06:58 -0600
committer GitHub <noreply@github.com>                   2024-02-13 14:06:58 +0200
commit    03bf161eb6dea6400ee49c6dc6b69bdcfa9fd3fc (patch)
tree      49320ac8aca35d2ba8162c2a280924bacbd7e06b /llama.h
parent    ad014bba97ef6ef6c3e2f78b2fc463e91ae94579 (diff)
llama : support batched embeddings (#5466)
* batched embedding: pool outputs by sequence id, updated embedding example
* bring back non-causal attention
* embd : minor improvements
* llama : minor

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'llama.h')
-rw-r--r--  llama.h  |  5 +++++
1 file changed, 5 insertions(+), 0 deletions(-)
diff --git a/llama.h b/llama.h
index 367e8f1a..5ef78ec9 100644
--- a/llama.h
+++ b/llama.h
@@ -236,6 +236,7 @@ extern "C" {
         bool logits_all;  // the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
         bool embedding;   // embedding mode only
         bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+        bool do_pooling;  // whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
     };

     // model quantization parameters
@@ -628,6 +629,10 @@ extern "C" {
     // shape: [n_embd] (1-dimensional)
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

+    // Get the embeddings for the ith sequence
+    // llama_get_embeddings(ctx) + i*n_embd
+    LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
+
     //
     // Vocab
     //
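
Below is a minimal usage sketch (not part of this commit) of how the new do_pooling flag and llama_get_embeddings_ith() fit together: decode two single-token sequences in one batch, then read back one pooled embedding per sequence. Model loading, tokenization, and error handling are omitted; the helper name embed_two_seqs and the single-token inputs are illustrative only.

#include "llama.h"

// Sketch: pooled embeddings for two sequences from a single llama_decode() call.
static void embed_two_seqs(struct llama_model * model, llama_token tok_a, llama_token tok_b) {
    struct llama_context_params cparams = llama_context_default_params();
    cparams.embedding  = true;  // embedding mode only
    cparams.do_pooling = true;  // pool outputs by sequence id (added by this commit)

    struct llama_context * ctx = llama_new_context_with_model(model, cparams);

    // one batch holding two sequences (ids 0 and 1), one token each at position 0
    struct llama_batch batch = llama_batch_init(/*n_tokens*/ 2, /*embd*/ 0, /*n_seq_max*/ 1);
    batch.n_tokens = 2;

    batch.token[0] = tok_a; batch.pos[0] = 0; batch.n_seq_id[0] = 1; batch.seq_id[0][0] = 0; batch.logits[0] = 1;
    batch.token[1] = tok_b; batch.pos[1] = 0; batch.n_seq_id[1] = 1; batch.seq_id[1][0] = 1; batch.logits[1] = 1;

    llama_decode(ctx, batch);  // returns 0 on success (error handling omitted)

    const int32_t n_embd = llama_n_embd(llama_get_model(ctx));

    // with do_pooling enabled, row i holds the pooled embedding of sequence i,
    // i.e. llama_get_embeddings(ctx) + i*n_embd
    const float * emb0 = llama_get_embeddings_ith(ctx, 0);
    const float * emb1 = llama_get_embeddings_ith(ctx, 1);

    (void) emb0; (void) emb1; (void) n_embd;  // use the embeddings here

    llama_batch_free(batch);
    llama_free(ctx);
}

Because pooling is keyed on sequence id, one decode call can produce embeddings for many inputs at once; this is the batched-embeddings behavior the commit title refers to.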