summary | refs | log | tree | commit | diff
path: root/examples
diff options
context:
space:
mode:
author howlger <eclipse@voormann.de> 2024-03-27 12:15:44 +0100
committer GitHub <noreply@github.com> 2024-03-27 13:15:44 +0200
commit 1e13987fba5a536965ef942f2c86549d62cef50b (patch)
tree d20b7bc3fc78d8a73133890d45f58624b68bcbac /examples
parent e82f9e2b833d88cd2b30123ef57346c2cb8abd99 (diff)
embedding : show full embedding for single prompt (#6342)
* embedding : show full embedding for single prompt

To support the use case of creating an embedding for a given prompt, the entire embedding and not just the first part needed to be printed. Also, show cosine similarity matrix only if there is more than one prompt, as the cosine similarity matrix for a single prompt is always `1.00`.

* Update examples/embedding/embedding.cpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Diffstat (limited to 'examples')
-rw-r--r-- examples/embedding/embedding.cpp | 20
1 files changed, 11 insertions, 9 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 9aede7fa..53665752 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -178,25 +178,27 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
- // print the first part of the embeddings
+ // print the first part of the embeddings or for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {
fprintf(stdout, "embedding %d: ", j);
- for (int i = 0; i < std::min(16, n_embd); i++) {
+ for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
}
fprintf(stdout, "\n");
}
// print cosine similarity matrix
- fprintf(stdout, "\n");
- printf("cosine similarity matrix:\n\n");
- for (int i = 0; i < n_prompts; i++) {
- for (int j = 0; j < n_prompts; j++) {
- float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
- fprintf(stdout, "%6.2f ", sim);
- }
+ if (n_prompts > 1) {
fprintf(stdout, "\n");
+ printf("cosine similarity matrix:\n\n");
+ for (int i = 0; i < n_prompts; i++) {
+ for (int j = 0; j < n_prompts; j++) {
+ float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+ fprintf(stdout, "%6.2f ", sim);
+ }
+ fprintf(stdout, "\n");
+ }
}
// clean up