diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-03-14 10:12:29 +0200 |
---|---|---|
committer | Georgi Gerganov <ggerganov@gmail.com> | 2024-03-14 10:12:29 +0200 |
commit | 0fd6c1f015f6cccf3b527f7dbd8386a434728126 (patch) | |
tree | 922b6618221ac91a03da9da22e3727f4cb40ceb1 /examples/gritlm/gritlm.cpp | |
parent | 19885d205e768579ab090d1e99281cae58c21b54 (diff) |
embedding : print cosine similarity (#899)
Diffstat (limited to 'examples/gritlm/gritlm.cpp')
-rw-r--r-- | examples/gritlm/gritlm.cpp | 26 |
1 files changed, 6 insertions, 20 deletions
```diff
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 3d4b085d..52fd719b 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -6,22 +6,6 @@
 
 // #define GRIT_DEBUG
 
-static float dot_product(const std::vector<float> & v1, const std::vector<float> & v2) {
-    float dot = 0.0f;
-    for (uint64_t i = 0; i < v1.size(); ++i) {
-        dot += v1[i] * v2[i];
-    }
-    return dot;
-}
-
-static float norm(const std::vector<float> & v) {
-    return std::sqrt(dot_product(v, v));
-}
-
-static float cosine_similarity(const std::vector<float> & v1, const std::vector<float> & v2) {
-    return dot_product(v1, v2) / (norm(v1) * norm(v2));
-}
-
 static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
     std::vector<std::vector<float>> result;
 
@@ -203,10 +187,12 @@ int main(int argc, char * argv[]) {
     const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
     const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
-    const float cosine_sim_q0_d0 = cosine_similarity(q_rep[0], d_rep[0]);
-    const float cosine_sim_q0_d1 = cosine_similarity(q_rep[0], d_rep[1]);
-    const float cosine_sim_q1_d0 = cosine_similarity(q_rep[1], d_rep[0]);
-    const float cosine_sim_q1_d1 = cosine_similarity(q_rep[1], d_rep[1]);
+    const int n_embd = llama_n_embd(mdl);
+
+    const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
+    const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
+    const float cosine_sim_q1_d0 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[0].data(), n_embd);
+    const float cosine_sim_q1_d1 = llama_embd_similarity_cos(q_rep[1].data(), d_rep[1].data(), n_embd);
 
     std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[0].c_str(), cosine_sim_q0_d0);
     std::printf("Cosine similarity between \"%.50s\" and \"%.50s\" is: %.3f\n", queries[0].c_str(), documents[1].c_str(), cosine_sim_q0_d1);
```