From d1031cf49c3b958b915fd558e23453471c29ac33 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 20 Oct 2023 21:07:23 +0300
Subject: sampling : refactor init to use llama_sampling_params (#3696)

* sampling : refactor init to use llama_sampling_params

* llama : combine repetition, frequency and presence penalties in 1 call

* examples : remove embd-input and gptneox-wip

* sampling : rename penalty params + reduce size of "prev" vector

* sampling : add llama_sampling_print helper

* sampling : hide prev behind API and apply #3661

ggml-ci
---
 examples/parallel/parallel.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'examples/parallel/parallel.cpp')

diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 69f9526a..eb64adef 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -157,7 +157,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params);
+        client.ctx_sampling = llama_sampling_init(params.sparams);
     }
 
     std::vector<llama_token> tokens_system;
@@ -330,7 +330,7 @@ int main(int argc, char ** argv) {
 
             const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
 
-            llama_sampling_accept(client.ctx_sampling, ctx, id);
+            llama_sampling_accept(client.ctx_sampling, ctx, id, true);
 
             if (client.n_decoded == 1) {
                 // start measuring generation time after the first token to make sure all concurrent clients
--
cgit v1.2.3
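
For reference, the per-client sampling flow after this commit looks roughly like the sketch below. It is assembled only from the calls visible in the diff above; the trimmed-down client struct, the helper name sample_for_clients, and the reading of the new trailing bool as "also apply the token to the grammar state" are assumptions for illustration, not part of this patch.

    // Sketch of the refactored sampling flow (assumptions noted inline).

    #include <cstddef>
    #include <vector>

    #include "common.h"   // gpt_params -- assumed to carry a llama_sampling_params member `sparams`
    #include "sampling.h" // llama_sampling_init / llama_sampling_sample / llama_sampling_accept

    // ASSUMPTION: minimal stand-in for the client struct used in parallel.cpp.
    struct client_state {
        int32_t id = 0;
        int32_t i_batch = -1;
        llama_sampling_context * ctx_sampling = nullptr;
    };

    void sample_for_clients(gpt_params & params, llama_context * ctx, std::vector<client_state> & clients) {
        // Each client owns its own sampling context, now initialized from the
        // dedicated sampling params instead of the whole gpt_params.
        for (size_t i = 0; i < clients.size(); ++i) {
            clients[i].id = i;
            clients[i].ctx_sampling = llama_sampling_init(params.sparams);
        }

        // After decoding a batch, sample one token per client from its logit slot...
        for (auto & client : clients) {
            const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch);

            // ...and record it in the client's sampling state (penalties, "prev" history).
            // ASSUMPTION: the new bool also advances the sampler's grammar state.
            llama_sampling_accept(client.ctx_sampling, ctx, id, true);
        }
    }

Keeping one llama_sampling_context per client, rather than per process, is what lets the concurrent sequences in this example apply repetition penalties and history independently of each other.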