diff options
Diffstat (limited to 'common')
| -rw-r--r-- | common/common.cpp | 13 | ||||
| -rw-r--r-- | common/common.h | 1 |
2 files changed, 14 insertions, 0 deletions
diff --git a/common/common.cpp b/common/common.cpp index 73b1b61b..58fbd05a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1877,3 +1877,16 @@ void llama_embd_normalize(const float * inp, float * out, int n) { } } +float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n){ + double sum = 0.0; + double sum1 = 0.0; + double sum2 = 0.0; + + for (int i = 0; i < n; i++) { + sum += embd1[i] * embd2[i]; + sum1 += embd1[i] * embd1[i]; + sum2 += embd2[i] * embd2[i]; + } + + return sum / (sqrt(sum1) * sqrt(sum2)); +} diff --git a/common/common.h b/common/common.h index 0f178b9e..d250eef8 100644 --- a/common/common.h +++ b/common/common.h @@ -268,3 +268,4 @@ void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40 void llama_embd_normalize(const float * inp, float * out, int n); +float llama_embd_similarity_cos(const float * embd1, const float * embd2, int n); |
