diff options
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r-- | examples/embedding/embedding.cpp | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index acff715e..ff5883da 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) { static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, false); + llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); } } -static void normalize(float * vec, float * out, int n) { +static void normalize(const float * vec, float * out, int n) { float norm = 0; for (int i = 0; i < n; i++) { norm += vec[i] * vec[i]; @@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } // normalize on copy - for (int k = 0; k < n_seq; k++) { - float * emb = llama_get_embeddings_ith(ctx, k); - float * out = output + k * n_embd; - normalize(emb, out, n_embd); + for (int i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + + // try to get sequence embeddings - supported only when pooling_type is not NONE + const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + if (embd == NULL) { + embd = llama_get_embeddings_ith(ctx, i); + if (embd == NULL) { + fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); + continue; + } + } + + float * out = output + batch.seq_id[i][0] * n_embd; + normalize(embd, out, n_embd); } } @@ -132,7 +145,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts); + struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output const int n_embd = llama_n_embd(model); @@ -145,6 +158,7 @@ int main(int argc, char ** argv) { for (int k = 0; k < n_prompts; k++) { // clamp to n_batch tokens auto & inp = inputs[k]; + const uint64_t n_toks = inp.size(); // encode if at capacity |