diff options
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r-- | examples/embedding/embedding.cpp | 14 |
1 files changed, 1 insertions, 13 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index ff5883da..a553ae1c 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke } } -static void normalize(const float * vec, float * out, int n) { - float norm = 0; - for (int i = 0; i < n; i++) { - norm += vec[i] * vec[i]; - } - norm = sqrt(norm); - for (int i = 0; i < n; i++) { - out[i] = vec[i] / norm; - } -} - static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); @@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu fprintf(stderr, "%s : failed to decode\n", __func__); } - // normalize on copy for (int i = 0; i < batch.n_tokens; i++) { if (!batch.logits[i]) { continue; @@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } float * out = output + batch.seq_id[i][0] * n_embd; - normalize(embd, out, n_embd); + llama_embd_normalize(embd, out, n_embd); } } |