diff options
author | Douglas Hanley <thesecretaryofwar@gmail.com> | 2024-06-21 00:38:22 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-06-21 08:38:22 +0300 |
commit | 80ea089d771f0c2d97afa8bead80ded412f600d7 (patch) | |
tree | 25c04a967b5913ffdc00d1a851dcfbeb9ab37a37 /examples | |
parent | 0e64591e8290037db6412665a56354b789a0597e (diff) |
llama : allow pooled embeddings on any model (#7477)
* create append_pooling operation; allow to specify attention_type; add last token pooling; update examples
* find result_norm/result_embd tensors properly; update output allocation logic
* only use embd output for pooling_type NONE
* get rid of old causal_attn accessor
* take out attention_type; add in llama_set_embeddings
* bypass logits when doing non-NONE pooling
Diffstat (limited to 'examples')
-rw-r--r-- | examples/embedding/embedding.cpp | 21 | ||||
-rw-r--r-- | examples/gritlm/gritlm.cpp | 6 | ||||
-rw-r--r-- | examples/retrieval/retrieval.cpp | 13 |
3 files changed, 25 insertions, 15 deletions
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 244751e0..b4b73c01 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -17,9 +17,10 @@ static std::vector<std::string> split_lines(const std::string & s) { return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -40,13 +41,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu // try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } - } + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); float * out = output + batch.seq_id[i][0] * n_embd; //TODO: I would also add a parameter here to enable normalization or not. @@ -97,6 +92,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 21351579..2c61c2e1 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); // run model @@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_token eos_token = llama_token_eos(mdl); llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true); @@ -166,8 +169,7 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); - // create new context - set to embedding mode - cparams.embeddings = true; + // create generation context llama_context * ctx = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 55b7b2f7..eb89d16d 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -73,9 +73,10 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, llama_seq_id seq_id) { + size_t n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + llama_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -160,6 +161,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); |