author     Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-10-11 13:35:46 -0600
committer  GitHub <noreply@github.com>                                2023-10-11 22:35:46 +0300
commit     70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch)
tree       9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d /examples/speculative/speculative.cpp
parent     8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff)
common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
  Code formatting cleanups and add some comments
  Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions.
  Fix comments that were out of sync with the pull.
* Use more consistent naming convention for sampling contexts
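The core idea of the refactor, per the bullets above: sampler state that was previously shared across the whole run (notably mirostat's running `mu`) moves into a context keyed by sequence id, so interleaved sequences no longer corrupt each other's state. Below is a minimal sketch of the shape implied by the diff on this page. Only `llama_sampling_context`, `llama_sampling_context_init`, `llama_sampling_sample`, `sequence_contexts`, and the per-entry `grammar` member are visible in the diff; the `mirostat_mu` field, the `llama_sampling_params` contents, and the `int` key type are assumptions, and the real definitions live in common/ at this commit.

#include <unordered_map>

struct llama_grammar; // opaque; managed by llama.cpp, only used via pointer here

// stand-in for the real sampling parameters (temperature, top-k,
// mirostat settings, ...); the actual struct lives in common/
struct llama_sampling_params {
    int   mirostat     = 0;    // 0 = off, 1/2 = mirostat v1/v2
    float mirostat_tau = 5.0f; // target surprise
    float mirostat_eta = 0.1f; // learning rate
};

// per-sequence sampler state. Before this change, mirostat's running
// `mu` was effectively a single shared value, so sampling from multiple
// sequences corrupted it -- the bug this PR fixes.
struct llama_sampler_sequence_context {
    float           mirostat_mu; // per-sequence mirostat state
    llama_grammar * grammar;     // this sequence's own grammar copy
};

struct llama_sampling_context {
    llama_sampling_params params; // shared sampling parameters

    // one entry per sequence id, created on first use; each entry
    // receives its own llama_grammar_copy() of `grammar` below
    std::unordered_map<int, llama_sampler_sequence_context> sequence_contexts;

    // template grammar (may be NULL); never advanced directly
    llama_grammar * grammar;
};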
Diffstat (limited to 'examples/speculative/speculative.cpp')
-rw-r--r--  examples/speculative/speculative.cpp  12
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 75a2e5e2..018dbf9a 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -125,6 +125,8 @@ int main(int argc, char ** argv) {
         grammar_tgt = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
     }
 
+    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);
+
     const auto t_dec_start = ggml_time_us();
 
     while (true) {
@@ -134,7 +136,7 @@ int main(int argc, char ** argv) {
 
         while (true) {
             // sample from the target model
-            llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
+            llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);
 
             // remember which tokens were sampled - used for repetition penalties during sampling
             last_tokens.erase(last_tokens.begin());
@@ -211,7 +213,13 @@ int main(int argc, char ** argv) {
             if (grammar_dft) {
                 llama_grammar_free(grammar_dft);
            }
-            grammar_dft = llama_grammar_copy(grammar_tgt);
+            // Note: Hardcoded to sequence id 0, if this ever supports parallel generation
+            // that will need to change.
+            auto it = ctx_sampling.sequence_contexts.find(0);
+            GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
+            // This is necessary because each sequence id in sequence_contexts
+            // uses a copy of the original grammar.
+            grammar_dft = llama_grammar_copy(it->second.grammar);
 
             LOG("copied target grammar to draft grammar\n");
         }
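Taken together, the hunks amount to: initialize the sampling context once, route all target-model sampling through it, and when re-seeding the draft grammar, copy the *sequence's* grammar rather than grammar_tgt, since the per-sequence copy has been advanced by every accepted token while grammar_tgt still holds the initial parse state. A sketch of the resulting flow, assuming ctx_tgt, params, grammar_tgt, grammar_dft, last_tokens, candidates, and i_dft are set up as elsewhere in speculative.cpp (this is a fragment for illustration, not a standalone program):

// one sampling context for the whole run, seeded with the target grammar
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar_tgt);

while (true) {
    // per-token sampling: mirostat state and grammar state are looked up
    // (and lazily created) per sequence id inside the context
    llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, i_dft);

    // ... accept/reject drafted tokens against the target model ...

    if (grammar_tgt) {
        if (grammar_dft) {
            llama_grammar_free(grammar_dft);
        }
        // copy the advanced grammar of sequence 0, not the pristine grammar_tgt
        auto it = ctx_sampling.sequence_contexts.find(0);
        GGML_ASSERT(it != ctx_sampling.sequence_contexts.end());
        grammar_dft = llama_grammar_copy(it->second.grammar);
    }
}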