summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-03-14 10:12:29 +0200
committerGeorgi Gerganov <ggerganov@gmail.com>2024-03-14 10:12:29 +0200
commit0fd6c1f015f6cccf3b527f7dbd8386a434728126 (patch)
tree922b6618221ac91a03da9da22e3727f4cb40ceb1
parent19885d205e768579ab090d1e99281cae58c21b54 (diff)
embedding : print cosine similarity (#899)
-rw-r--r--common/common.cpp13
-rw-r--r--common/common.h1
-rw-r--r--examples/embedding/embedding.cpp21
-rw-r--r--examples/gritlm/gritlm.cpp26
4 files changed, 36 insertions, 25 deletions
diff --git a/common/common.cpp b/common/common.cpp
index 73b1b61b..58fbd05a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1877,3 +1877,16 @@ void llama_embd_normalize(const float * inp, float * out, int n) {
}
}
+float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){
+ double sum = 0.0;
+ double sum1 = 0.0;
+ double sum2 = 0.0;
+
+ for (int i = 0; i < n; i++) {
+ sum += embd1[i] * embd2[i];
+ sum1 += embd1[i] * embd1[i];
+ sum2 += embd2[i] * embd2[i];
+ }
+
+ return sum / (sqrt(sum1) * sqrt(sum2));
+}
diff --git a/common/common.h b/common/common.h
index 0f178b9e..d250eef8 100644
--- a/common/common.h
+++ b/common/common.h
@@ -268,3 +268,4 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40
void llama_embd_normalize(const float * inp, float * out, int n);
+float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n);
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 49302a19..f390c406 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -168,14 +168,25 @@ int main(int argc, char ** argv) {
batch_decode(ctx, batch, out, s, n_embd);
// print first 3 embeddings
+ fprintf(stdout, "\n");
for (int j = 0; j < std::min(3, n_prompts); j++) {
- fprintf(stderr, "embedding %d: ", j);
- for (int i = 0; i < n_embd; i++) {
- fprintf(stderr, "%f ", emb[j * n_embd + i]);
+ fprintf(stdout, "embedding %d: ", j);
+ for (int i = 0; i < std::min(16, n_embd); i++) {
+ fprintf(stdout, "%f ", emb[j * n_embd + i]);
}
- fprintf(stderr, "\n\n");
+ fprintf(stdout, "\n");
+ }
+
+ // print cosine similarity matrix
+ fprintf(stdout, "\n");
+ printf("cosine similarity matrix:\n\n");
+ for (int i = 0; i < n_prompts; i++) {
+ for (int j = 0; j < n_prompts; j++) {
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+ fprintf(stdout, "%6.2f ", sim);
+ }
+ fprintf(stdout, "\n");
}
- fprintf(stderr, "\n");
// clean up
llama_print_timings(ctx);
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 3d4b085d..52fd719b 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -6,22 +6,6 @@
// #define GRIT_DEBUG
-static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
- float dot = 0.0f;
- for (uint64_t i = 0; i < v1.size(); ++i) {
- dot += v1[i] * v2[i];
- }
- return dot;
-}
-
-static float norm(const std::vector<float> & v) {
- return std::sqrt(dot_product(v, v));
-}
-
-static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
- return dot_product(v1, v2) / (norm(v1) * norm(v2));
-}
-
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
std::vector<std::vector<float>> result;
@@ -203,10 +187,12 @@ int main(int argc, char * argv[]) {
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
- const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
- const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
- const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
- const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
+ const int n_embd = llama_n_embd(mdl);
+
+ const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
+ const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
+ const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
+ const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);