summaryrefslogtreecommitdiff
path: root/examples/gritlm/gritlm.cpp
diff options
context:
space:
mode:
authorDouglas Hanley <thesecretaryofwar@gmail.com>2024-06-21 00:38:22 -0500
committerGitHub <noreply@github.com>2024-06-21 08:38:22 +0300
commit80ea089d771f0c2d97afa8bead80ded412f600d7 (patch)
tree25c04a967b5913ffdc00d1a851dcfbeb9ab37a37 /examples/gritlm/gritlm.cpp
parent0e64591e8290037db6412665a56354b789a0597e (diff)
llama : allow pooled embeddings on any model (#7477)
* create append_pooling operation; allow specifying attention_type; add last token pooling; update examples * find result_norm/result_embd tensors properly; update output allocation logic * only use embd output for pooling_type NONE * get rid of old causal_attn accessor * take out attention_type; add in llama_set_embeddings * bypass logits when doing non-NONE pooling
Diffstat (limited to 'examples/gritlm/gritlm.cpp')
-rw-r--r--examples/gritlm/gritlm.cpp | 6
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 21351579..2c61c2e1 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
+ llama_set_embeddings(ctx, true);
llama_set_causal_attn(ctx, false);
// run model
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx);
+ llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);
+
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) {
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
- // create new context - set to embedding mode
- cparams.embeddings = true;
+ // create generation context
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
// ### Embedding/Representation ###