diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-03-14 10:12:29 +0200 |
---|---|---|
committer | Georgi Gerganov <ggerganov@gmail.com> | 2024-03-14 10:12:29 +0200 |
commit | 0fd6c1f015f6cccf3b527f7dbd8386a434728126 (patch) | |
tree | 922b6618221ac91a03da9da22e3727f4cb40ceb1 /examples/embedding/embedding.cpp | |
parent | 19885d205e768579ab090d1e99281cae58c21b54 (diff) |
embedding : print cosine similarity (#899)
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r-- | examples/embedding/embedding.cpp | 21 |
1 files changed, 16 insertions, 5 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 49302a19..f390c406 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -168,14 +168,25 @@ int main(int argc, char ** argv) { batch_decode(ctx, batch, out, s, n_embd); // print first 3 embeddings + fprintf(stdout, "\n"); for (int j = 0; j < std::min(3, n_prompts); j++) { - fprintf(stderr, "embedding %d: ", j); - for (int i = 0; i < n_embd; i++) { - fprintf(stderr, "%f ", emb[j * n_embd + i]); + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < std::min(16, n_embd); i++) { + fprintf(stdout, "%f ", emb[j * n_embd + i]); } - fprintf(stderr, "\n\n"); + fprintf(stdout, "\n"); + } + + // print cosine similarity matrix + fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "\n"); } - fprintf(stderr, "\n"); // clean up llama_print_timings(ctx); |