Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r--  examples/embedding/embedding.cpp  28
1 file changed, 21 insertions, 7 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index acff715e..ff5883da 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -19,11 +19,11 @@ static std::vector<std::string> split_lines(const std::string & s) {
 
 static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     for (size_t i = 0; i < tokens.size(); i++) {
-        llama_batch_add(batch, tokens[i], i, { seq_id }, false);
+        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
     }
 }
 
-static void normalize(float * vec, float * out, int n) {
+static void normalize(const float * vec, float * out, int n) {
     float norm = 0;
     for (int i = 0; i < n; i++) {
         norm += vec[i] * vec[i];
@@ -45,10 +45,23 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }
 
     // normalize on copy
-    for (int k = 0; k < n_seq; k++) {
-        float * emb = llama_get_embeddings_ith(ctx, k);
-        float * out = output + k * n_embd;
-        normalize(emb, out, n_embd);
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+
+        // try to get sequence embeddings - supported only when pooling_type is not NONE
+        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+        if (embd == NULL) {
+            embd = llama_get_embeddings_ith(ctx, i);
+            if (embd == NULL) {
+                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
+                continue;
+            }
+        }
+
+        float * out = output + batch.seq_id[i][0] * n_embd;
+        normalize(embd, out, n_embd);
     }
 }
 
@@ -132,7 +145,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
 
     // allocate output
     const int n_embd = llama_n_embd(model);
@@ -145,6 +158,7 @@ int main(int argc, char ** argv) {
     for (int k = 0; k < n_prompts; k++) {
         // clamp to n_batch tokens
         auto & inp = inputs[k];
+
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
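
A minimal sketch (not part of this commit) of the lookup pattern introduced in batch_decode() above, assuming a context created with embeddings enabled and an already-decoded batch; the helper name get_seq_embedding is hypothetical. Because batch_add_seq() now marks only the last token of each sequence for output, a caller skips tokens whose batch.logits entry is false, then prefers the pooled per-sequence embedding and falls back to the per-token embedding.

#include "llama.h"

// Hypothetical helper mirroring the fallback added in batch_decode():
// prefer the pooled embedding for the token's sequence (available when
// pooling_type is not NONE), otherwise use the per-token embedding at
// position i in the batch.
static const float * get_seq_embedding(llama_context * ctx, const llama_batch & batch, int i) {
    const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
    if (embd == NULL) {
        embd = llama_get_embeddings_ith(ctx, i);
    }
    return embd; // may still be NULL if neither output is available
}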