summaryrefslogtreecommitdiff
path: root/examples/embedding
diff options
context:
space:
mode:
authorSeungWon Jeong <65549245+redlion0929@users.noreply.github.com>2024-03-09 21:27:58 +0900
committerGitHub <noreply@github.com>2024-03-09 14:27:58 +0200
commitfb215c3832236fec7380c4fb618bd7154cb196ef (patch)
tree1c2d0eb8fce7d2c1f70024b1f2c7b4b35baaa029 /examples/embedding
parent2c4f566c88322ebf2f9bd11b01b5ebdaa0130b89 (diff)
server : normalize embeddings (#5956)
* output normalize embedding in '/v1/embeddings' * common : reuse llama_embd_normalize * common : better normalize impl --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples/embedding')
-rw-r--r--examples/embedding/embedding.cpp14
1 files changed, 1 insertions, 13 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index ff5883da..a553ae1c 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
}
}
-static void normalize(const float * vec, float * out, int n) {
- float norm = 0;
- for (int i = 0; i < n; i++) {
- norm += vec[i] * vec[i];
- }
- norm = sqrt(norm);
- for (int i = 0; i < n; i++) {
- out[i] = vec[i] / norm;
- }
-}
-
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
fprintf(stderr, "%s : failed to decode\n", __func__);
}
- // normalize on copy
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
}
float * out = output + batch.seq_id[i][0] * n_embd;
- normalize(embd, out, n_embd);
+ llama_embd_normalize(embd, out, n_embd);
}
}