summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r--  common/common.cpp                 | 15
-rw-r--r--  common/common.h                   |  7
-rw-r--r--  examples/embedding/embedding.cpp  | 14
-rw-r--r--  examples/server/server.cpp        |  8
4 files changed, 30 insertions(+), 14 deletions(-)
diff --git a/common/common.cpp b/common/common.cpp
index d7f650ef..16ef4d7f 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1852,3 +1852,18 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
printf("\n=== Done dumping\n");
}
+
+void llama_embd_normalize(const float * inp, float * out, int n) {
+ double sum = 0.0;
+ for (int i = 0; i < n; i++) {
+ sum += inp[i] * inp[i];
+ }
+ sum = sqrt(sum);
+
+ const float norm = sum > 0.0 ? 1.0f / sum : 0.0f;
+
+ for (int i = 0; i < n; i++) {
+ out[i] = inp[i] * norm;
+ }
+}
+
diff --git a/common/common.h b/common/common.h
index 977ce419..f8d82b87 100644
--- a/common/common.h
+++ b/common/common.h
@@ -260,3 +260,10 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
+
+//
+// Embedding utils
+//
+
+void llama_embd_normalize(const float * inp, float * out, int n);
+
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index ff5883da..a553ae1c 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -23,17 +23,6 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens
}
}
-static void normalize(const float * vec, float * out, int n) {
- float norm = 0;
- for (int i = 0; i < n; i++) {
- norm += vec[i] * vec[i];
- }
- norm = sqrt(norm);
- for (int i = 0; i < n; i++) {
- out[i] = vec[i] / norm;
- }
-}
-
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
@@ -44,7 +33,6 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * output
fprintf(stderr, "%s : failed to decode\n", __func__);
}
- // normalize on copy
for (int i = 0; i < batch.n_tokens; i++) {
if (!batch.logits[i]) {
continue;
@@ -61,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * output
}
float * out = output + batch.seq_id[i][0] * n_embd;
- normalize(embd, out, n_embd);
+ llama_embd_normalize(embd, out, n_embd);
}
}
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8cff514f..796f3499 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1327,6 +1327,8 @@ struct server_context {
const int n_embd = llama_n_embd(model);
+ std::vector<float> embd_res(n_embd, 0.0f);
+
for (int i = 0; i < batch.n_tokens; ++i) {
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
continue;
@@ -1350,8 +1352,10 @@ struct server_context {
continue;
}
+ llama_embd_normalize(embd, embd_res.data(), n_embd);
+
res.data = json {
- {"embedding", std::vector<float>(embd, embd + n_embd)},
+ {"embedding", embd_res},
};
}
@@ -3354,6 +3358,8 @@ int main(int argc, char ** argv) {
// get the result
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ // append to the responses
responses.push_back(result.data);
}