diff options
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r-- | examples/embedding/embedding.cpp | 148 |
1 files changed, 102 insertions, 46 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 1466e5b2..b05aa006 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -31,13 +31,24 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke } static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + const struct llama_model * model = llama_get_model(ctx); + // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - fprintf(stderr, "%s : failed to decode\n", __func__); + if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { + // encoder-only model + if (llama_encode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to encode\n", __func__); + } + } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + // decoder-only model + if (llama_decode(ctx, batch) < 0) { + fprintf(stderr, "%s : failed to decode\n", __func__); + } } for (int i = 0; i < batch.n_tokens; i++) { @@ -45,11 +56,22 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu continue; } - // try to get sequence embeddings - supported only when pooling_type is not NONE - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + const float * embd = nullptr; + int embd_pos = 0; + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings + embd = llama_get_embeddings_ith(ctx, i); + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + embd_pos = batch.seq_id[i][0]; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + } - float * out = output + batch.seq_id[i][0] * n_embd; + float * out = output + embd_pos * n_embd; llama_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -79,11 +101,11 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; - // load the model - std::tie(model, ctx) = llama_init_from_gpt_params(params); + llama_init_result llama_init = llama_init_from_gpt_params(params); + + llama_model * model = llama_init.model; + llama_context * ctx = llama_init.context; if (model == NULL) { fprintf(stderr, "%s: error: unable to load model\n", __func__); return 1; @@ -93,8 +115,9 @@ int main(int argc, char ** argv) { const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + + if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + fprintf(stderr, "%s: error: computing embeddings in encoder-decoder models is not supported\n", __func__); return 1; } @@ -153,13 +176,23 @@ int main(int argc, char ** argv) { const int n_prompts = prompts.size(); struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + // count number of embeddings + int n_embd_count = 0; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int k = 0; k < n_prompts; k++) { + n_embd_count += inputs[k].size(); + } + } else { + n_embd_count = n_prompts; + } + // allocate output const int n_embd = llama_n_embd(model); - std::vector<float> embeddings(n_prompts * n_embd, 0); + std::vector<float> embeddings(n_embd_count * n_embd, 0); float * emb = embeddings.data(); // break into batches - int p = 0; // number of prompts processed already + int e = 0; // number of embeddings already stored int s = 0; // number of prompts in current batch for (int k = 0; k < n_prompts; k++) { // clamp to n_batch tokens @@ -169,11 +202,11 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { - float * out = emb + p * n_embd; + float * out = emb + e * n_embd; batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - llama_batch_clear(batch); - p += s; + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; s = 0; + llama_batch_clear(batch); } // add to batch @@ -182,39 +215,62 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + p * n_embd; + float * out = emb + e * n_embd; batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); if (params.embd_out.empty()) { - // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); - for (int j = 0; j < n_prompts; j++) { - fprintf(stdout, "embedding %d: ", j); - for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { - if (params.embd_normalize == 0) { - fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); - } else { - fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int j = 0; j < n_embd_count; j++) { + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < std::min(3, n_embd); i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } + } + fprintf(stdout, " ... "); + for (int i = n_embd - 3; i < n_embd; i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } } + fprintf(stdout, "\n"); } - fprintf(stdout, "\n"); - } - - // print cosine similarity matrix - if (n_prompts > 1) { - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); - for (int i = 0; i < n_prompts; i++) { - fprintf(stdout, "%6.6s ", prompts[i].c_str()); + } else { + // print the first part of the embeddings or for a single prompt, the full embedding + for (int j = 0; j < n_prompts; j++) { + fprintf(stdout, "embedding %d: ", j); + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { + if (params.embd_normalize == 0) { + fprintf(stdout, "%6.0f ", emb[j * n_embd + i]); + } else { + fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); + } + } + fprintf(stdout, "\n"); } - fprintf(stdout, "\n"); - for (int i = 0; i < n_prompts; i++) { - for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); + + // print cosine similarity matrix + if (n_prompts > 1) { + fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + fprintf(stdout, "%6.6s ", prompts[i].c_str()); } - fprintf(stdout, "%1.10s", prompts[i].c_str()); fprintf(stdout, "\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "%1.10s", prompts[i].c_str()); + fprintf(stdout, "\n"); + } } } } @@ -233,23 +289,23 @@ int main(int argc, char ** argv) { } fprintf(stdout, notArray ? "]\n }" : "]"); j++; - if (j < n_prompts) fprintf(stdout, notArray ? ",\n" : ","); else break; + if (j < n_embd_count) fprintf(stdout, notArray ? ",\n" : ","); else break; } fprintf(stdout, notArray ? "\n ]" : "]\n"); if (params.embd_out == "json+" && n_prompts > 1) { fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); - for (int i = 0;;) { // at least two iteration (n_prompts > 1) + for (int i = 0;;) { // at least two iteration (n_embd_count > 1) fprintf(stdout, " ["); - for (int j = 0;;) { // at least two iteration (n_prompts > 1) + for (int j = 0;;) { // at least two iteration (n_embd_count > 1) float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); fprintf(stdout, "%6.2f", sim); j++; - if (j < n_prompts) fprintf(stdout, ", "); else break; + if (j < n_embd_count) fprintf(stdout, ", "); else break; } fprintf(stdout, " ]"); i++; - if (i < n_prompts) fprintf(stdout, ",\n"); else break; + if (i < n_embd_count) fprintf(stdout, ",\n"); else break; } fprintf(stdout, "\n ]"); } |