Diffstat (limited to 'examples/parallel/parallel.cpp')
-rw-r--r--	examples/parallel/parallel.cpp	6
1 file changed, 4 insertions, 2 deletions
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 04f1e45b..63ddcd8e 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);
+
// load the prompts from an external file if there are any
if (params.prompt.empty()) {
printf("\n\033[32mNo new questions so proceed with built-in defaults.\033[0m\n");
@@ -339,7 +341,7 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
- const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
+ const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id);
if (client.n_decoded == 1) {
// start measuring generation time after the first token to make sure all concurrent clients
@@ -384,7 +386,7 @@ int main(int argc, char ** argv) {
n_total_prompt += client.n_prompt;
n_total_gen += client.n_decoded;
-
+ llama_sampling_context_reset(ctx_sampling, client.seq_id);
client.seq_id = -1;
}
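For context, the pattern this patch introduces can be summarized as: create one sampling context after the model is initialized, sample per client with that client's sequence id, and reset the per-sequence state when the client finishes. The sketch below is a simplified illustration assuming only the signatures visible in the hunks above (llama_sampling_context_init, llama_sampling_sample, llama_sampling_context_reset); the client struct and loop are stand-ins, not the full parallel.cpp code.

    // Minimal sketch of the flow shown in this diff (assumed context, not the full example).

    // 1. Initialize one shared sampling context right after the model/context,
    //    with no grammar attached:
    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL);

    // 2. For each client that has a token to sample, pass its seq_id so that
    //    per-sequence sampling state is tracked separately:
    const llama_token id = llama_sampling_sample(
        ctx, NULL, ctx_sampling,
        client.tokens_prev, candidates,
        client.i_batch - i,   // index of this client's logits within the batch
        client.seq_id);       // per-sequence sampling state

    // 3. When a client finishes generating, clear that sequence's sampling state
    //    before the slot is reused for a new request:
    llama_sampling_context_reset(ctx_sampling, client.seq_id);
    client.seq_id = -1;

The reset step matters because client slots are recycled across requests: without it, a new request assigned to the same seq_id would inherit the previous request's sampling history.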