summaryrefslogtreecommitdiff
path: root/examples/embedding/embedding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r--examples/embedding/embedding.cpp21
1 files changed, 16 insertions, 5 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 49302a19..f390c406 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -168,14 +168,25 @@ int main(int argc, char ** argv) {
batch_decode(ctx, batch, out, s, n_embd);
// print first 3 embeddings
+ fprintf(stdout, "\n");
for (int j = 0; j < std::min(3, n_prompts); j++) {
- fprintf(stderr, "embedding %d: ", j);
- for (int i = 0; i < n_embd; i++) {
- fprintf(stderr, "%f ", emb[j * n_embd + i]);
+ fprintf(stdout, "embedding %d: ", j);
+ for (int i = 0; i < std::min(16, n_embd); i++) {
+ fprintf(stdout, "%f ", emb[j * n_embd + i]);
}
- fprintf(stderr, "\n\n");
+ fprintf(stdout, "\n");
+ }
+
+ // print cosine similarity matrix
+ fprintf(stdout, "\n");
+ printf("cosine similarity matrix:\n\n");
+ for (int i = 0; i < n_prompts; i++) {
+ for (int j = 0; j < n_prompts; j++) {
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+ fprintf(stdout, "%6.2f ", sim);
+ }
+ fprintf(stdout, "\n");
}
- fprintf(stderr, "\n");
// clean up
llama_print_timings(ctx);