summaryrefslogtreecommitdiff
path: root/examples/embedding
diff options
context:
space:
mode:
Diffstat (limited to 'examples/embedding')
-rw-r--r--examples/embedding/embedding.cpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 9aede7fa..53665752 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -178,25 +178,27 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
- // print the first part of the embeddings
+ // print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
- for (int i = 0; i < std::min(16, n_embd); i++) {
+ for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n");
}
// print cosine similarity matrix
- fprintf(stdout, "\n");
- printf("cosine similarity matrix:\n\n");
- for (int i = 0; i < n_prompts; i++) {
- for (int j = 0; j < n_prompts; j++) {
- float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
- fprintf(stdout, "%6.2f ", sim);
- }
+ if (n_prompts > 1) {
fprintf(stdout, "\n");
+ printf("cosine similarity matrix:\n\n");
+ for (int i = 0; i < n_prompts; i++) {
+ for (int j = 0; j < n_prompts; j++) {
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+ fprintf(stdout, "%6.2f ", sim);
+ }
+ fprintf(stdout, "\n");
+ }
}
// clean up