From 3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa Mon Sep 17 00:00:00 2001
From: firecoperana
Date: Thu, 19 Jun 2025 02:24:53 -0500
Subject: add dry sampler (#513)

* add dry sampler

* use vocab instead of model in dry_init function

* fix compile error for build test

---------

Co-authored-by: firecoperana
---
 examples/infill/infill.cpp           |  2 +-
 examples/llava/llava-cli.cpp         |  2 +-
 examples/llava/minicpmv-cli.cpp      |  2 +-
 examples/lookahead/lookahead.cpp     |  2 +-
 examples/lookup/lookup.cpp           |  2 +-
 examples/main/main.cpp               |  2 +-
 examples/parallel/parallel.cpp       |  2 +-
 examples/rpc/CMakeLists.txt          |  8 +++++-
 examples/server/CMakeLists.txt       |  8 +++++-
 examples/server/server.cpp           | 54 +++++++++++++++++++++++++++++++++++-
 examples/speculative/speculative.cpp |  4 +--
 11 files changed, 76 insertions(+), 12 deletions(-)

(limited to 'examples')

diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 92d630b1..d3c3ad5a 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -349,7 +349,7 @@ int main(int argc, char ** argv) {

     std::vector<llama_token> embd;

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);

     while (n_remain != 0 || params.interactive) {
         // predict
diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 8c7dd2ae..889a6222 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

     LOG_TEE("\n");

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp
index f951b57b..022508a2 100644
--- a/examples/llava/minicpmv-cli.cpp
+++ b/examples/llava/minicpmv-cli.cpp
@@ -218,7 +218,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla

     LOG_TEE("\n");

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams);

     return ctx_sampling;
 }
diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp
index 81cf1629..b817be2d 100644
--- a/examples/lookahead/lookahead.cpp
+++ b/examples/lookahead/lookahead.cpp
@@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);

     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp
index d53a9828..1fff4f74 100644
--- a/examples/lookup/lookup.cpp
+++ b/examples/lookup/lookup.cpp
@@ -106,7 +106,7 @@ int main(int argc, char ** argv){

     bool has_eos = false;

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);

     std::vector<llama_token> draft;

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 6e0635a6..de736f08 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -531,7 +531,7 @@ int main(int argc, char ** argv) {
         antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
     }

-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams);
     if (!ctx_sampling) {
         fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
         exit(1);
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 621a1c95..ee614f84 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
         client.id = i;
-        client.ctx_sampling = llama_sampling_init(params.sparams);
+        client.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams);
     }

     std::vector<llama_token> tokens_system;
diff --git a/examples/rpc/CMakeLists.txt b/examples/rpc/CMakeLists.txt
index 41b22863..815636fe 100644
--- a/examples/rpc/CMakeLists.txt
+++ b/examples/rpc/CMakeLists.txt
@@ -1,4 +1,10 @@
 set(TARGET rpc-server)
 add_executable(${TARGET} rpc-server.cpp)
 target_link_libraries(${TARGET} PRIVATE ggml)
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
\ No newline at end of file
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
+if (MSVC)
+    target_link_options(${TARGET} PRIVATE
+        $<$:/STACK:20971520,1048576 >
+        $<$:/STACK:20971520,1048576>
+    )
+endif()
\ No newline at end of file
diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt
index 33fd25f0..20ddc5c5 100644
--- a/examples/server/CMakeLists.txt
+++ b/examples/server/CMakeLists.txt
@@ -37,7 +37,13 @@ install(TARGETS ${TARGET} RUNTIME)
 target_compile_definitions(${TARGET} PRIVATE
     SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
 )
-
+if (MSVC)
+    target_link_options(${TARGET} PRIVATE
+        $<$:/STACK:20971520,1048576 >
+        $<$:/STACK:20971520,1048576>
+    )
+endif()
+# target_link_libraries(${TARGET} PRIVATE "/STACK:104857600")
 target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
 target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 31f7383e..563570ad 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -977,6 +977,10 @@ struct server_context {
         slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
         slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
         slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+        slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+        slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
         slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
         slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
         slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
@@ -987,6 +991,42 @@ struct server_context {
         slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
         slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+        if (slot.sparams.penalty_last_n < -1) {
+            throw std::runtime_error("Error: repeat_last_n must be >= -1");
+        }
+
+        if (slot.sparams.dry_penalty_last_n < -1) {
+            throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
+        }
+
+        if (slot.sparams.penalty_last_n == -1) {
+            // note: should be the slot's context and not the full context, but it's ok
+            slot.sparams.penalty_last_n = llama_n_ctx(ctx);
+        }
+
+        if (slot.sparams.dry_penalty_last_n == -1) {
+            slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx);
+
+        }
+        if (slot.sparams.dry_base < 1.0f)
+        {
+            slot.sparams.dry_base = default_sparams.dry_base;
+        }
+
+        // sequence breakers for DRY
+        {
+            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+            if (data.contains("dry_sequence_breakers")) {
+                slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+                if (slot.sparams.dry_sequence_breakers.empty()) {
+                    send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+                    return false;
+                }
+            }
+        }
+
         // process "json_schema" and "grammar"
         if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
             send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
@@ -1156,7 +1196,7 @@ struct server_context {
         if (slot.ctx_sampling != nullptr) {
             llama_sampling_free(slot.ctx_sampling);
         }
-        slot.ctx_sampling = llama_sampling_init(slot.sparams);
+        slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams);
         if (slot.ctx_sampling == nullptr) {
             // for now, the only error that may happen here is invalid grammar
             send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1405,6 +1445,11 @@ struct server_context {
             {"frequency_penalty", slot.sparams.penalty_freq},
             {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
             {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
+            {"dry_multiplier", slot.sparams.dry_multiplier},
+            {"dry_base", slot.sparams.dry_base},
+            {"dry_allowed_length", slot.sparams.dry_allowed_length},
+            {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
+            {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
             {"mirostat", slot.sparams.mirostat},
             {"mirostat_tau", slot.sparams.mirostat_tau},
             {"mirostat_eta", slot.sparams.mirostat_eta},
@@ -2337,6 +2382,13 @@ struct server_context {
                 slot.command = SLOT_COMMAND_NONE;

                 GGML_ASSERT(batch.n_tokens > 0);
+                llama_sampling_reset(slot.ctx_sampling);
+                for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+                    llama_token id = slot.prompt_tokens[i];
+                    if (id != LLAMA_TOKEN_NULL) {
+                        llama_sampling_accept(slot.ctx_sampling, ctx, id, false);
+                    }
+                }

                 // extract the logits only for the last token
                 batch.logits[batch.n_tokens - 1] = true;
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index b051a18f..3063f0b6 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -179,7 +179,7 @@ int main(int argc, char ** argv) {
     bool has_eos = false;

     // target model sampling context
-    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_tgt), params.sparams);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
@@ -190,7 +190,7 @@ int main(int argc, char ** argv) {
     }

     for (int s = 0; s < n_seq_dft; ++s) {
-        drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
+        drafts[s].ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_dft), params.sparams);
     }

     llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
--
cgit v1.2.3
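
Usage note (not part of the patch): every example above now passes the model vocabulary into llama_sampling_init() via llama_get_model_vocab(), so that the DRY initialization can work from the vocabulary (see the "use vocab instead of model in dry_init function" commit message above). Below is a minimal caller-side sketch of the new DRY parameters. The llama_sampling_params type name and the numeric values are illustrative assumptions, not defaults taken from this patch; the field names, the dry_base < 1.0f fallback, and the dry_penalty_last_n == -1 "whole context" behavior come from the server.cpp changes shown above.

    // Illustrative sketch only -- not part of the patch.
    llama_sampling_params sparams;
    sparams.dry_multiplier        = 0.8f;   // a value > 0 enables the DRY penalty
    sparams.dry_base              = 1.75f;  // values < 1.0f fall back to the default (see server.cpp above)
    sparams.dry_allowed_length    = 2;      // repetitions up to this length are not penalized
    sparams.dry_penalty_last_n    = -1;     // -1 = consider the whole context (mapped to llama_n_ctx() above)
    sparams.dry_sequence_breakers = { "\n", ":", "\"", "*" };   // example breaker strings

    // the vocab argument lets the sampler resolve the sequence breakers
    struct llama_sampling_context * ctx_sampling =
            llama_sampling_init(llama_get_model_vocab(model), sparams);

Over the HTTP server, the same settings are exposed as the JSON request fields dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, and dry_sequence_breakers (which must be a non-empty array of strings if provided), read via json_value() in server.cpp above.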