From 29ae62d2ae163e2b68aa0ad3bf2ab4636de0c957 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 4 Mar 2024 22:31:20 +0200
Subject: llama : fix embeddings (#5796)

* llama : fix embeddings

ggml-ci

* llama : do not use KV cache for non-causal models

ggml-ci

* embeddings : fix llama_batch_init arg

* llama : add pooling switch

* llama : distinguish token vs sequence embeddings

ggml-ci

* llama : assert pooling tensor

* llama : simplify causal mask condition

ggml-ci

* llama : assert input batch with pooling enabled

* readme : update API changes list
---
 examples/embedding/embedding.cpp | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index acff715e..ff5883da 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
 
 static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
-        normalize(emb, out, n_embd);
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
+        float * out = output + batch.seq_id[i][0] * n_embd;
+        normalize(embd, out, n_embd);
     }
 }
 
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // allocate output
     const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
 
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
         auto & inp = inputs[k];
+        const uint64_t n_toks = inp.size();
 
         // encode if at capacity
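
For context, the retrieval order this patch establishes can be sketched
standalone. The snippet below is a hypothetical usage example, not part of the
commit: it assumes a model/ctx pair already created with embeddings enabled in
the context params, assumes the includes of embedding.cpp (llama.h, common.h,
vector, cstdio), uses the llama_tokenize and llama_batch_add helpers from
common.h, and reuses the normalize() helper shown in the diff above. It embeds
one prompt and tries pooled per-sequence embeddings first, falling back to the
per-token embedding of the output row:

    // sketch: embed a single prompt in sequence 0 (see assumptions above)
    std::vector<llama_token> toks = ::llama_tokenize(ctx, "hello world", true);

    llama_batch batch = llama_batch_init(512, 0, 1);
    for (size_t i = 0; i < toks.size(); i++) {
        // after this patch, only the last token of a sequence requests an output row
        llama_batch_add(batch, toks[i], i, { 0 }, i == toks.size() - 1);
    }

    if (llama_decode(ctx, batch) < 0) {
        fprintf(stderr, "decode failed\n");
    }

    const int n_embd = llama_n_embd(model);

    // pooled sequence embedding - non-NULL only when pooling_type != NONE ...
    const float * embd = llama_get_embeddings_seq(ctx, 0);
    if (embd == NULL) {
        // ... otherwise fall back to the per-token embedding of the output row
        embd = llama_get_embeddings_ith(ctx, batch.n_tokens - 1);
    }

    if (embd != NULL) {
        std::vector<float> out(n_embd);
        normalize(embd, out.data(), n_embd);
    }

    llama_batch_free(batch);

Marking only the final token of each sequence for output (instead of none, as
before) is what makes the llama_get_embeddings_ith() fallback usable when no
pooling is configured, while llama_get_embeddings_seq() covers the pooled case.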