diff options
author | Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> | 2023-10-11 13:35:46 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-11 22:35:46 +0300 |
commit | 70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch) | |
tree | 9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d /examples/parallel/parallel.cpp | |
parent | 8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff) |
common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
Code formatting cleanups and add some comments
Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions.
Fix comments that were out of sync with the pull.
* Use more consistant naming convention for sampling contexts
Diffstat (limited to 'examples/parallel/parallel.cpp')
-rw-r--r-- | examples/parallel/parallel.cpp | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 04f1e45b..63ddcd8e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -125,6 +125,8 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model, ctx) = llama_init_from_gpt_params(params); + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, NULL); + // load the prompts from an external file if there are any if (params.prompt.empty()) { printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); @@ -339,7 +341,7 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i); + const llama_token id = llama_sampling_sample(ctx, NULL, ctx_sampling, client.tokens_prev, candidates, client.i_batch - i, client.seq_id); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -384,7 +386,7 @@ int main(int argc, char ** argv) { n_total_prompt += client.n_prompt; n_total_gen += client.n_decoded; - + llama_sampling_context_reset(ctx_sampling, client.seq_id); client.seq_id = -1; } |