From 70c29da118cdb02bfcbd0376c32b5b2236e48e48 Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Wed, 11 Oct 2023 13:35:46 -0600 Subject: common : fix mirostat state when using multiple sequences (#3543) * Fix mirostat state when using multiple sequences * Fix mirostat by completely refactoring sampling! * Try to fix zig build. * Export function to fetch/create default sampler states Code formatting cleanups and add some comments Silence a warning about id not being used when logging is disabled * Apply some renaming suggestions. Fix comments that were out of sync with the pull. * Use more consistent naming convention for sampling contexts --- examples/speculative/speculative.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) (limited to 'examples/speculative/speculative.cpp') diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 75a2e5e2..018dbf9a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -125,6 +125,8 @@ int main(int argc, char ** argv) { grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); } + llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt); + const auto t_dec_start = ggml_time_us(); while (true) { @@ -134,7 +136,7 @@ int main(int argc, char ** argv) { while (true) { // sample from the target model - llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); + llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft); // remember which tokens were sampled - used for repetition penalties during sampling last_tokens.erase(last_tokens.begin()); @@ -211,7 +213,13 @@ int main(int argc, char ** argv) { if (grammar_dft) { llama_grammar_free(grammar_dft); } - grammar_dft = llama_grammar_copy(grammar_tgt); + // Note: Hardcoded 
to sequence id 0, if this ever supports parallel generation + // that will need to change. + auto it = ctx_sampling.sequence_contexts.find(0); + GGML_ASSERT(it != ctx_sampling.sequence_contexts.end()); + // This is necessary because each sequence id in sequence_contexts + // uses a copy of the original grammar. + grammar_dft = llama_grammar_copy(it->second.grammar); LOG("copied target grammar to draft grammar\n"); } -- cgit v1.2.3