diff options
Diffstat (limited to 'common/common.cpp')
-rw-r--r-- | common/common.cpp | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index 73b1b61b..58fbd05a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1877,3 +1877,16 @@ void llama_embd_normalize(const float * inp, float * out, int n) { } } +float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){ + double sum = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + + for (int i = 0; i < n; i++) { + sum += embd1[i] * embd2[i]; + sum1 += embd1[i] * embd1[i]; + sum2 += embd2[i] * embd2[i]; + } + + return sum / (sqrt(sum1) * sqrt(sum2)); +} |