diff options
Diffstat (limited to 'examples/gritlm/gritlm.cpp')
-rw-r--r-- | examples/gritlm/gritlm.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 21351579..2c61c2e1 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve // clear previous kv_cache values (irrelevant for embeddings) llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); // run model @@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_token eos_token = llama_token_eos(mdl); llama_kv_cache_clear(ctx); + llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); + llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true); @@ -166,8 +169,7 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); - // create new context - set to embedding mode - cparams.embeddings = true; + // create generation context llama_context * ctx = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### |