summaryrefslogtreecommitdiff
path: root/examples/embedding/embedding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'examples/embedding/embedding.cpp')
-rw-r--r--examples/embedding/embedding.cpp21
1 files changed, 11 insertions, 10 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 244751e0..b4b73c01 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -17,9 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) {
return lines;
}
-static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
- for (size_t i = 0; i < tokens.size(); i++) {
- llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
+static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) {
+ size_t n_tokens = tokens.size();
+ for (size_t i = 0; i < n_tokens; i++) {
+ llama_batch_add(batch, tokens[i], i, { seq_id }, true);
}
}
@@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
// try to get sequence embeddings - supported only when pooling_type is not NONE
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
- if (embd == NULL) {
- embd = llama_get_embeddings_ith(ctx, i);
- if (embd == NULL) {
- fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
- continue;
- }
- }
+ GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
float * out = output + batch.seq_id[i][0] * n_embd;
//TODO: I would also add a parameter here to enable normalization or not.
@@ -97,6 +92,12 @@ int main(int argc, char ** argv) {
const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx);
+ const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
+ if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
+ return 1;
+ }
+
if (n_ctx > n_ctx_train) {
fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
__func__, n_ctx_train, n_ctx);