Diffstat (limited to 'examples/embedding/embedding.cpp')
 examples/embedding/embedding.cpp | 115 ++++++++++++++++++++++++++-----------
 1 file changed, 82 insertions(+), 33 deletions(-)
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index b4b73c01..1466e5b2 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -7,13 +7,19 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static std::vector<std::string> split_lines(const std::string & s) {
-    std::string line;
+static std::vector<std::string> split_lines(const std::string & s, const std::string & separator = "\n") {
     std::vector<std::string> lines;
-    std::stringstream ss(s);
-    while (std::getline(ss, line)) {
-        lines.push_back(line);
+    size_t start = 0;
+    size_t end = s.find(separator);
+
+    while (end != std::string::npos) {
+        lines.push_back(s.substr(start, end - start));
+        start = end + separator.length();
+        end = s.find(separator, start);
     }
+
+    lines.push_back(s.substr(start)); // Add the last part
+
     return lines;
 }
 
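For context: the reworked split_lines splits on an arbitrary separator string instead of only "\n", and always emits the trailing segment. A minimal usage sketch, assuming the function above is in scope (the separator and input strings below are illustrative):

    // splits on a multi-character separator; "<#sep#>" is a made-up example value
    std::vector<std::string> parts = split_lines("a<#sep#>b<#sep#>c", "<#sep#>");
    // parts == {"a", "b", "c"}
    // note the behavioral difference from the old std::getline loop: a trailing
    // separator now yields a trailing empty string, e.g.
    // split_lines("a\n", "\n") == {"a", ""}, where the old code returned {"a"}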
@@ -24,7 +30,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
     }
 }
 
-static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
+static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
     // clear previous kv_cache values (irrelevant for embeddings)
     llama_kv_cache_clear(ctx);
 
@@ -44,13 +50,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
         GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
 
         float * out = output + batch.seq_id[i][0] * n_embd;
-        //TODO: I would also add a parameter here to enable normalization or not.
-        /*fprintf(stdout, "unnormalized_embedding:");
-        for (int hh = 0; hh < n_embd; hh++) {
-            fprintf(stdout, "%9.6f ", embd[hh]);
-        }
-        fprintf(stdout, "\n");*/
-        llama_embd_normalize(embd, out, n_embd);
+        llama_embd_normalize(embd, out, n_embd, embd_norm);
     }
 }
 
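The new embd_norm parameter is forwarded verbatim to llama_embd_normalize, which before this change always applied Euclidean (L2) normalization. A rough sketch of what that former default computes, as an illustrative standalone helper rather than the common-library implementation:

    #include <cmath>

    // L2-normalize n values from inp into out; mirrors what the old
    // unconditional llama_embd_normalize(embd, out, n_embd) call did
    static void embd_normalize_l2(const float * inp, float * out, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += (double) inp[i] * inp[i];
        }
        const double scale = sum > 0.0 ? 1.0 / std::sqrt(sum) : 0.0;
        for (int i = 0; i < n; i++) {
            out[i] = (float) (inp[i] * scale);
        }
    }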
@@ -110,7 +110,7 @@ int main(int argc, char ** argv) {
     }
 
     // split the prompt into lines
-    std::vector<std::string> prompts = split_lines(params.prompt);
+    std::vector<std::string> prompts = split_lines(params.prompt, params.embd_sep);
 
     // max batch size
     const uint64_t n_batch = params.n_batch;
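params.embd_sep carries the separator in from the command line. A hypothetical invocation tying the new options together (the flag names are inferred from the params fields and should be checked against the options added in common in the same change):

    ./embedding -m model.gguf \
        --embd-separator '<#sep#>' \
        --embd-normalize 2 \
        --embd-output-format json \
        -p 'first prompt<#sep#>second prompt'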
@@ -170,7 +170,7 @@ int main(int argc, char ** argv) {
         // encode if at capacity
         if (batch.n_tokens + n_toks > n_batch) {
             float * out = emb + p * n_embd;
-            batch_decode(ctx, batch, out, s, n_embd);
+            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
             llama_batch_clear(batch);
             p += s;
             s = 0;
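For orientation, both batch_decode call sites sit in a prompt loop of roughly this shape (a paraphrase of the example's flow, not a verbatim excerpt): each tokenized prompt is appended to the shared batch as its own sequence, and the batch is flushed before it would overflow n_batch:

    // paraphrase: p counts prompts already decoded, s counts sequences batched so far
    int p = 0;
    int s = 0;
    for (int k = 0; k < n_prompts; k++) {
        const auto & inp = inputs[k];            // tokenized prompt k
        const uint64_t n_toks = inp.size();
        if (batch.n_tokens + n_toks > n_batch) { // encode if at capacity
            float * out = emb + p * n_embd;
            batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
            llama_batch_clear(batch);
            p += s;
            s = 0;
        }
        batch_add_seq(batch, inp, s);
        s += 1;
    }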
@@ -183,29 +183,78 @@ int main(int argc, char ** argv) {
 
     // final batch
     float * out = emb + p * n_embd;
-    batch_decode(ctx, batch, out, s, n_embd);
-
-    // print the first part of the embeddings or for a single prompt, the full embedding
-    fprintf(stdout, "\n");
-    for (int j = 0; j < n_prompts; j++) {
-        fprintf(stdout, "embedding %d: ", j);
-        for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
-            fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
-        }
-        fprintf(stdout, "\n");
-    }
+    batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
 
-    // print cosine similarity matrix
-    if (n_prompts > 1) {
+    if (params.embd_out.empty()) {
+        // print the first part of the embeddings or for a single prompt, the full embedding
         fprintf(stdout, "\n");
-        printf("cosine similarity matrix:\n\n");
-        for (int i = 0; i < n_prompts; i++) {
-            for (int j = 0; j < n_prompts; j++) {
-                float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
-                fprintf(stdout, "%6.2f ", sim);
+        for (int j = 0; j < n_prompts; j++) {
+            fprintf(stdout, "embedding %d: ", j);
+            for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
+                if (params.embd_normalize == 0) {
+                    fprintf(stdout, "%6.0f ", emb[j * n_embd + i]);
+                } else {
+                    fprintf(stdout, "%9.6f ", emb[j * n_embd + i]);
+                }
+            }
+            fprintf(stdout, "\n");
+        }
+
+        // print cosine similarity matrix
+        if (n_prompts > 1) {
+            fprintf(stdout, "\n");
+            printf("cosine similarity matrix:\n\n");
+            for (int i = 0; i < n_prompts; i++) {
+                fprintf(stdout, "%6.6s ", prompts[i].c_str());
             }
             fprintf(stdout, "\n");
+            for (int i = 0; i < n_prompts; i++) {
+                for (int j = 0; j < n_prompts; j++) {
+                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    fprintf(stdout, "%6.2f ", sim);
+                }
+                fprintf(stdout, "%1.10s", prompts[i].c_str());
+                fprintf(stdout, "\n");
+            }
+        }
+    }
+
+    if (params.embd_out == "json" || params.embd_out == "json+" || params.embd_out == "array") {
+        const bool notArray = params.embd_out != "array";
+
+        fprintf(stdout, notArray ? "{\n  \"object\": \"list\",\n  \"data\": [\n" : "[");
+        for (int j = 0;;) { // at least one iteration (one prompt)
+            if (notArray) fprintf(stdout, "    {\n      \"object\": \"embedding\",\n      \"index\": %d,\n      \"embedding\": ",j);
+            fprintf(stdout, "[");
+            for (int i = 0;;) { // at least one iteration (n_embd > 0)
+                fprintf(stdout, params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
+                i++;
+                if (i < n_embd) fprintf(stdout, ","); else break;
+            }
+            fprintf(stdout, notArray ? "]\n    }" : "]");
+            j++;
+            if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break;
         }
+        fprintf(stdout, notArray ? "\n  ]" : "]\n");
+
+        if (params.embd_out == "json+" && n_prompts > 1) {
+            fprintf(stdout, ",\n  \"cosineSimilarity\": [\n");
+            for (int i = 0;;) { // at least two iterations (n_prompts > 1)
+                fprintf(stdout, "    [");
+                for (int j = 0;;) { // at least two iterations (n_prompts > 1)
+                    float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
+                    fprintf(stdout, "%6.2f", sim);
+                    j++;
+                    if (j < n_prompts) fprintf(stdout, ", "); else break;
+                }
+                fprintf(stdout, " ]");
+                i++;
+                if (i < n_prompts) fprintf(stdout, ",\n"); else break;
+            }
+            fprintf(stdout, "\n  ]");
+        }
+
+        if (notArray) fprintf(stdout, "\n}\n");
     }
 
     // clean up
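Pieced together from the format strings above, the "json" output mode produces output of the following shape for two prompts (embedding values elided); "array" emits only the bare array of arrays, and "json+" additionally appends the cosineSimilarity matrix shown here before the closing brace:

    {
      "object": "list",
      "data": [
        {
          "object": "embedding",
          "index": 0,
          "embedding": [...]
        },
        {
          "object": "embedding",
          "index": 1,
          "embedding": [...]
        }
      ],
      "cosineSimilarity": [
        [  1.00,   ... ],
        [   ...,  1.00 ]
      ]
    }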