diff options
author | firecoperana <xuqiaowei1124@gmail.com> | 2025-06-19 02:24:53 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-06-19 10:24:53 +0300 |
commit | 3f111ad7bbb2d4f721332f9b2b344e48b3bbf9aa (patch) | |
tree | a3a17ee74e0436253e17f0d322320ed554d34b0a | |
parent | c5368148cf3af7a3694e0eb03d24a08326c01d12 (diff) |
add dry sampler (#513)
* add dry sampler
* use vocab instead of model in dry_init function
* fix compile error for build test
---------
Co-authored-by: firecoperana <firecoperana>
-rw-r--r-- | common/common.cpp | 50 | ||||
-rw-r--r-- | common/sampling.cpp | 69 | ||||
-rw-r--r-- | common/sampling.h | 17 | ||||
-rw-r--r-- | examples/infill/infill.cpp | 2 | ||||
-rw-r--r-- | examples/llava/llava-cli.cpp | 2 | ||||
-rw-r--r-- | examples/llava/minicpmv-cli.cpp | 2 | ||||
-rw-r--r-- | examples/lookahead/lookahead.cpp | 2 | ||||
-rw-r--r-- | examples/lookup/lookup.cpp | 2 | ||||
-rw-r--r-- | examples/main/main.cpp | 2 | ||||
-rw-r--r-- | examples/parallel/parallel.cpp | 2 | ||||
-rw-r--r-- | examples/rpc/CMakeLists.txt | 8 | ||||
-rw-r--r-- | examples/server/CMakeLists.txt | 8 | ||||
-rw-r--r-- | examples/server/server.cpp | 54 | ||||
-rw-r--r-- | examples/speculative/speculative.cpp | 4 | ||||
-rw-r--r-- | include/llama.h | 27 | ||||
-rw-r--r-- | src/llama-impl.h | 114 | ||||
-rw-r--r-- | src/llama-sampling.cpp | 311 | ||||
-rw-r--r-- | src/llama-sampling.h | 32 | ||||
-rw-r--r-- | src/llama-vocab.cpp | 19 | ||||
-rw-r--r-- | src/llama-vocab.h | 7 | ||||
-rw-r--r-- | src/llama.cpp | 45 |
21 files changed, 743 insertions, 36 deletions
diff --git a/common/common.cpp b/common/common.cpp index 20e583fc..208d4511 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -666,6 +666,47 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.top_n_sigma = std::stof(argv[i]); return true; } + + if (arg == "--dry-multiplier") { + CHECK_ARG + sparams.dry_multiplier = std::stof(argv[i]); + return true; + } + if (arg == "--dry-base") { + CHECK_ARG + sparams.dry_base = std::stof(argv[i]); + return true; + } + if (arg == "--dry-allowed-length") { + CHECK_ARG + sparams.dry_allowed_length = std::stof(argv[i]); + return true; + } + if (arg == "--dry-penalty-last-n") { + CHECK_ARG + sparams.dry_penalty_last_n = std::stof(argv[i]); + return true; + } + if (arg == "--dry-sequence-breaker") { + CHECK_ARG + static bool defaults_cleared = false; + + if (!defaults_cleared) { + params.sparams.dry_sequence_breakers.clear(); + defaults_cleared = true; + } + std::string value= std::string(argv[i]); + if (value == "none") { + params.sparams.dry_sequence_breakers.clear(); + } + else { + for (size_t i; i < value.size(); i++) + { + params.sparams.dry_sequence_breakers.emplace_back(""+value[i]); + } + } + return true; + } if (arg == "--cfg-negative-prompt") { CHECK_ARG sparams.cfg_negative_prompt = argv[i]; @@ -2326,6 +2367,11 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; } + if (params.sparams.dry_penalty_last_n == -1) { + LOG("%s: setting dry_penalty_last_n to ctx_size = %d\n", __func__, llama_n_ctx(lctx)); + params.sparams.dry_penalty_last_n = llama_n_ctx(lctx); + } + if (params.warmup) { LOG("warming up the model with an empty run\n"); @@ -3389,6 +3435,10 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); + fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length); + fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base); + fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier); + fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); diff --git a/common/sampling.cpp b/common/sampling.cpp index 4db12ee1..4b983e5f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,8 +1,9 @@ #define LLAMA_API_INTERNAL #include "sampling.h" +#include "llama-vocab.h" #include <random> -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { +struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); result->params = params; @@ -36,13 +37,32 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ } result->grammar = grammar; } - result->prev.resize(params.n_prev); result->n_valid = 0; + // init DRY + for (const auto& cnstr : params.samplers_sequence) + { + switch (cnstr) + { + case llama_sampler_type::DRY: + { + std::vector<const char*> c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto& str : params.dry_sequence_breakers) + { + c_breakers.push_back(str.c_str()); + } + result->smpl=llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()); + + break; + } + default: + break; + } + } llama_sampling_set_rng_seed(result, params.seed); - return result; } @@ -50,7 +70,8 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { if (ctx->grammar != NULL) { llama_grammar_free(ctx->grammar); } - + if (ctx->smpl !=NULL) + llama_sampler_dry_free(ctx->smpl); delete ctx; } @@ -75,6 +96,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); ctx->n_valid = 0; + llama_sampler_dry_reset(ctx->smpl); } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -95,6 +117,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds } dst->prev = src->prev; + dst->smpl = llama_sampler_dry_clone(src->smpl); } llama_token llama_sampling_last(llama_sampling_context * ctx) { @@ -149,6 +172,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { switch (sampler_type) { + case llama_sampler_type::DRY: return "dry"; case llama_sampler_type::TOP_K: return "top_k"; case llama_sampler_type::TFS_Z: return "tfs_z"; case llama_sampler_type::TYPICAL_P: return "typical_p"; @@ -163,6 +187,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map { + {"dry", llama_sampler_type::DRY}, {"top_k", llama_sampler_type::TOP_K}, {"top_p", llama_sampler_type::TOP_P}, {"typical_p", llama_sampler_type::TYPICAL_P}, @@ -176,6 +201,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto // since samplers names are written multiple ways // make it ready for both system names and input names std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map { + {"dry", llama_sampler_type::DRY}, {"top-k", llama_sampler_type::TOP_K}, {"top-p", llama_sampler_type::TOP_P}, {"nucleus", llama_sampler_type::TOP_P}, @@ -215,6 +241,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) { std::unordered_map<char, llama_sampler_type> sampler_name_map { + {'d', llama_sampler_type::DRY}, {'k', llama_sampler_type::TOP_K}, {'p', llama_sampler_type::TOP_P}, {'y', llama_sampler_type::TYPICAL_P}, @@ -238,25 +265,28 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin // no reasons to expose this function in header static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, - llama_token_data_array & cur_p, - size_t min_keep) { - const float temp = params.temp; - const float dynatemp_range = params.dynatemp_range; + struct llama_context* ctx_main, + const llama_sampling_params& params, + llama_sampling_context * ctx_sampling, + llama_token_data_array& cur_p, + size_t min_keep) { + const float temp = params.temp; + const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const float xtc_probability = params.xtc_probability; - const float xtc_threshold = params.xtc_threshold; - const float top_n_sigma = params.top_n_sigma; + const int32_t top_k = params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const float xtc_probability = params.xtc_probability; + const float xtc_threshold = params.xtc_threshold; + const float top_n_sigma = params.top_n_sigma; + const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence; for (auto sampler_type : samplers_sequence) { switch (sampler_type) { + case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break; case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; @@ -317,7 +347,7 @@ static llama_token llama_sampling_sample_impl( // temperature sampling size_t min_keep = std::max(1, params.min_keep); - sampler_queue(ctx_main, params, cur_p, min_keep); + sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep); id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); @@ -472,4 +502,5 @@ void llama_sampling_accept( if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } + llama_sampler_dry_accept(ctx_sampling->smpl, id); } diff --git a/common/sampling.h b/common/sampling.h index 4fc86595..1d5bf0b9 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -35,11 +35,16 @@ typedef struct llama_sampling_params { float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities float dynatemp_range = 0.00f; // 0.0 = disabled float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat = 1.00f; // 1.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float dry_multiplier = 0.0f; // 0.0 = disabled; DRY repetition penalty for tokens extending repetition: + float dry_base = 1.75f; // 0.0 = disabled; multiplier * base ^ (length of sequence before token - allowed length) + int32_t dry_allowed_length = 2; // tokens extending repetitions beyond this receive penalty + int32_t dry_penalty_last_n = -1; // how many tokens to scan for repetitions (0 = disable penalty, -1 = context size) + int32_t total_context_size = 16840; + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate float xtc_probability = 0.0f; // xtc probability @@ -48,12 +53,16 @@ typedef struct llama_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context + std::vector<std::string> dry_sequence_breakers = { "\n", ":", "\"", "*" }; // default sequence breakers for DRY + std::vector<llama_sampler_type> samplers_sequence = { + llama_sampler_type::DRY, llama_sampler_type::TOP_K, llama_sampler_type::TFS_Z, llama_sampler_type::TYPICAL_P, llama_sampler_type::TOP_P, llama_sampler_type::MIN_P, + llama_sampler_type::XTC, llama_sampler_type::TOP_N_SIGMA, llama_sampler_type::TEMPERATURE }; @@ -88,6 +97,8 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector<llama_token> prev; std::vector<llama_token_data> cur; + llama_sampler_dry* smpl; + size_t n_valid; // Number of correct top tokens with correct probabilities. std::mt19937 rng; @@ -96,7 +107,7 @@ struct llama_sampling_context { #include "common.h" // Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); +struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params); void llama_sampling_free(struct llama_sampling_context * ctx); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 92d630b1..d3c3ad5a 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -349,7 +349,7 @@ int main(int argc, char ** argv) { std::vector<llama_token> embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams); while (n_remain != 0 || params.interactive) { // predict diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 8c7dd2ae..889a6222 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index f951b57b..022508a2 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -218,7 +218,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(ctx_llava->model),params->sparams); return ctx_sampling; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 81cf1629..b817be2d 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -118,7 +118,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams); // verification n-grams std::vector<ngram_data> ngrams_cur(G); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index d53a9828..1fff4f74 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -106,7 +106,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams); std::vector<llama_token> draft; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e0635a6..de736f08 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -531,7 +531,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), sparams); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 621a1c95..ee614f84 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams); + client.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model), params.sparams); } std::vector<llama_token> tokens_system; diff --git a/examples/rpc/CMakeLists.txt b/examples/rpc/CMakeLists.txt index 41b22863..815636fe 100644 --- a/examples/rpc/CMakeLists.txt +++ b/examples/rpc/CMakeLists.txt @@ -1,4 +1,10 @@ set(TARGET rpc-server) add_executable(${TARGET} rpc-server.cpp) target_link_libraries(${TARGET} PRIVATE ggml) -target_compile_features(${TARGET} PRIVATE cxx_std_17)
\ No newline at end of file +target_compile_features(${TARGET} PRIVATE cxx_std_17) +if (MSVC) + target_link_options(${TARGET} PRIVATE + $<$<CONFIG:DEBUG>:/STACK:20971520,1048576 > + $<$<CONFIG:RELEASE>:/STACK:20971520,1048576> + ) +endif()
\ No newline at end of file diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 33fd25f0..20ddc5c5 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -37,7 +37,13 @@ install(TARGETS ${TARGET} RUNTIME) target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}> ) - +if (MSVC) + target_link_options(${TARGET} PRIVATE + $<$<CONFIG:DEBUG>:/STACK:20971520,1048576 > + $<$<CONFIG:RELEASE>:/STACK:20971520,1048576> + ) +endif() +# target_link_libraries(${TARGET} PRIVATE "/STACK:104857600") target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT}) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 31f7383e..563570ad 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -977,6 +977,10 @@ struct server_context { slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat); slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq); slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present); + slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier); + slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base); + slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length); + slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n); slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat); slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau); slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta); @@ -987,6 +991,42 @@ struct server_context { slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs); slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep); + if (slot.sparams.penalty_last_n < -1) { + throw std::runtime_error("Error: repeat_last_n must be >= -1"); + } + + if (slot.sparams.dry_penalty_last_n < -1) { + throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); + } + + if (slot.sparams.penalty_last_n == -1) { + // note: should be the slot's context and not the full context, but it's ok + slot.sparams.penalty_last_n = llama_n_ctx(ctx); + } + + if (slot.sparams.dry_penalty_last_n == -1) { + slot.sparams.dry_penalty_last_n = llama_n_ctx(ctx); + + } + if (slot.sparams.dry_base < 1.0f) + { + slot.sparams.dry_base = default_sparams.dry_base; + } + + // sequence breakers for DRY + { + // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format + // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 + + if (data.contains("dry_sequence_breakers")) { + slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>()); + if (slot.sparams.dry_sequence_breakers.empty()) { + send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST); + return false; + } + } + } + // process "json_schema" and "grammar" if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); @@ -1156,7 +1196,7 @@ struct server_context { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); + slot.ctx_sampling = llama_sampling_init(llama_get_model_vocab(model),slot.sparams); if (slot.ctx_sampling == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -1405,6 +1445,11 @@ struct server_context { {"frequency_penalty", slot.sparams.penalty_freq}, {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, + {"dry_multiplier", slot.sparams.dry_multiplier}, + {"dry_base", slot.sparams.dry_base}, + {"dry_allowed_length", slot.sparams.dry_allowed_length}, + {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n}, + {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, @@ -2337,6 +2382,13 @@ struct server_context { slot.command = SLOT_COMMAND_NONE; GGML_ASSERT(batch.n_tokens > 0); + llama_sampling_reset(slot.ctx_sampling); + for (int i = 0; i < slot.n_prompt_tokens; ++i) { + llama_token id = slot.prompt_tokens[i]; + if (id != LLAMA_TOKEN_NULL) { + llama_sampling_accept(slot.ctx_sampling, ctx, id, false); + } + } // extract the logits only for the last token batch.logits[batch.n_tokens - 1] = true; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index b051a18f..3063f0b6 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -179,7 +179,7 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_tgt), params.sparams); // draft sequence data std::vector<seq_draft> drafts(n_seq_dft); @@ -190,7 +190,7 @@ int main(int argc, char ** argv) { } for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + drafts[s].ctx_sampling = llama_sampling_init(llama_get_model_vocab(model_dft), params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); diff --git a/include/llama.h b/include/llama.h index f1645228..bf26a55f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -40,6 +40,8 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_TOKEN_NULL -1 + #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' @@ -556,6 +558,7 @@ extern "C" { LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); + LLAMA_API const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); @@ -1222,6 +1225,30 @@ extern "C" { llama_token_data_array * candidates_p, float top_n_sigma); + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + LLAMA_API struct llama_sampler_dry * llama_sampler_init_dry( + const struct llama_vocab* model, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char** seq_breakers, + size_t num_breakers); + + //LLAMA_API void llama_sample_dry(struct llama_context* ctx, llama_token_data_array* candidates_p, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers); + + void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p); + + void llama_sampler_dry_reset(struct llama_sampler_dry* smpl); + + void llama_sampler_dry_free(struct llama_sampler_dry* smpl); + + struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl); + + void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token); + + /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. diff --git a/src/llama-impl.h b/src/llama-impl.h index a9cbe0df..a50f60cf 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -9,6 +9,7 @@ #define LLAMA_API_INTERNAL #include "llama.h" +#include <stdexcept> #ifdef __GNUC__ #ifdef __MINGW32__ @@ -20,6 +21,7 @@ #define LLAMA_ATTRIBUTE_FORMAT(...) #endif + // // logging // @@ -52,3 +54,115 @@ static void replace_all(std::string & s, const std::string & search, const std:: builder.append(s, last_pos, std::string::npos); s = std::move(builder); } + + +// the ring buffer works similarly to std::deque, but with a fixed capacity +template<typename T> +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T& front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T& front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T& back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T& back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T& value) { + if (capacity == 0) { + throw std::runtime_error("ring buffer: capacity is zero"); + } + + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } + else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T& rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector<T> to_vector() const { + std::vector<T> result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector<T> data; +}; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7a185c5b..40d9963d 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,4 +1,6 @@ #include "llama-sampling.h" +#include "llama-vocab.h" +#include "llama-grammar.h" #include <algorithm> #include <cstring> @@ -469,7 +471,7 @@ void llama_sample_xtc_impl(struct llama_sampling * smpl, llama_token_data_array } void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma) { - + if (top_n_sigma <= 0.0f || candidates->size < 4) { // top_n_sigma <= 0: disabled // candidates->size < 4: no point in applying the transformation for fewer than 4 logits. @@ -725,3 +727,310 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); } + + +// DRY + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +static void get_overlapping_token_sequences(const llama_vocab& vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) { + for (llama_token token_id = 0; token_id < (llama_token)vocab.n_tokens(); token_id++) { + std::string word = llama_detokenize(vocab, { token_id }, true); + if (word.find(str) != std::string::npos) { + token_sequences.emplace(token_id, std::vector<llama_token>()); + } + else { + size_t word_len = word.size(), str_len = str.size(); + size_t pos = -1; + while ((pos = word.find(str[0], pos + 1)) != std::string::npos) { + bool match = true; + size_t i; + for (i = 1; i < str_len && i + pos < word_len; ++i) { + if (word[pos + i] != str[i]) { + match = false; + break; + } + } + if (match) { + std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); + if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { + tokenization.resize(max_tail_len); + } + + // Ensure we don't already have a duplicate matching tokenization + auto its = token_sequences.equal_range(token_id); + bool found = false; + for (auto it = its.first; it != its.second; ++it) { + if (tokenization == it->second) { + found = true; + break; + } + } + if (!found) { + token_sequences.emplace(token_id, tokenization); + } + } + } + } + } +} + +static const char* llama_sampler_dry_name(const struct llama_sampler* /*smpl*/) { + return "dry"; +} + + + +// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) +void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p) { + if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) { + return; + } + + int32_t effective_dry_penalty_last_n = (smpl->dry_penalty_last_n == -1) ? smpl->total_context_size : std::max(smpl->dry_penalty_last_n, 0); + int last_n_repeat = std::min(std::min((int)smpl->last_tokens.size(), effective_dry_penalty_last_n), smpl->total_context_size); + + if (last_n_repeat <= smpl->dry_allowed_length) { + return; + } + + smpl->dry_repeat_count.assign(last_n_repeat, 0); + smpl->dry_max_token_repeat.clear(); + + // Step 1: Look for restart sequences to limit the maximum repetition length. + // Work backwards through the context looking for any token that begins a restart sequence. + // + // The collection `restart_sequences` is a mapping from a "head" token to all "tail" + // sequences that together comprise a restart sequence. This allows us to quickly check + // whether each token is the head of a complete sequence. Most restart sequences are actually + // a single token, and for these the "tail" is an empty vector. + // + // If the token is a "head", test all restart sequences that begin with this token + // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and + // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The + // longest matching sequence (if any) is used to limit the maximum repetition length. + // + // Note that in the case case of a short sequence contained in a longer one, this might fail to + // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as + // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress + // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare. + // + // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we + // have already clamped the maximum tail sequence length when generating `restart_sequences`. + // With clamping, this scan is O(N) in the context length. + + int rep_limit = last_n_repeat; + for (int i = 0; i < last_n_repeat; ++i) { + llama_token token = smpl->last_tokens.rat(i); + auto its = smpl->dry_processed_breakers.equal_range(token); + if (its.first == smpl->dry_processed_breakers.end()) { + continue; + } + int longest_match = -1; + for (auto it = its.first; it != its.second; ++it) { + // Note that (*it) does not contain the head character, so seq_len will be + // the restart sequence length minus 1. + // In the common case of a single-token restart sequence, (*it) will be empty + // and we will trivially match. + int seq_len = (int)it->second.size(); + if (seq_len > longest_match && seq_len <= (int)i) { + bool match = true; + for (int offset = 0; offset < seq_len; ++offset) { + // The -1 when indexing `last_tokens` is because we already matched the head. + if (it->second[offset] != smpl->last_tokens.rat(i - offset - 1)) { + match = false; + break; + } + } + if (match) { + longest_match = seq_len; + } + } + } + if (longest_match >= 0) { + // We found a restart sequence starting `i` tokens from the end and continuing for + // `longest_match` tokens. + rep_limit = i - longest_match; + break; + } + } + if (rep_limit < smpl->dry_allowed_length) { + return; + } + + // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in + // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing + // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences. + // + // This algorithm is not currently documented on Wikipedia, but there is a clear description here: + // https://ivanyu.me/blog/2014/10/15/z-algorithm/ + // + // The code below is adapted from the public domain implementation by the same author here: + // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py + // + // Example: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // ^ + // This `3` means that the last three tokens of the context (a b c) also appear here. + // + // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested + // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each + // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables + // ensure that the inner while loops only examine each token in the context once as the outer + // for loop iterates over the context. + + { + const int last = last_n_repeat - 1; + int rt = 0, lt = 0; + + for (int k = 1; k < last_n_repeat; ++k) { + if (k > rt) { + // If k is outside the current Z-box, do naive computation. + int n = 0; + while (n + k < last_n_repeat && smpl->last_tokens.rat(n) == smpl->last_tokens.rat(n + k)) { + ++n; + } + smpl->dry_repeat_count[last - k] = std::min(n, rep_limit); + if (n > 0) { + lt = k; + rt = k + n - 1; + } + } + else { + // If k is inside the current Z-box, consider two cases. + + int p = k - lt; // Pair index. + int right_part_len = rt - k + 1; + + if (smpl->dry_repeat_count[last - p] < right_part_len) { + int n = std::min(smpl->dry_repeat_count[last - p], rep_limit); + smpl->dry_repeat_count[last - k] = n; + } + else { + int i = rt + 1; + while (i < last_n_repeat && smpl->last_tokens.rat(i) == smpl->last_tokens.rat(i - k)) { + i += 1; + } + + int n = std::min(i - k, rep_limit); + smpl->dry_repeat_count[last - k] = n; + lt = k; + rt = i - 1; + } + } + } + } + + // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length + // that would be generated by emitting each new token that would extend a sequence. + // + // Following the same example as above: + // Last N tokens: a b c c b c y a b c + // Repeat counts: 0 0 3 1 0 2 0 0 0 0 + // + // For each non-zero, look ahead one token. This token, if emitted, would extend the repetition. + // c: 3 -> 4 (from `a b c` to `a b c c`) + // b: 1 -> 2 (from `c` to `c b`) + // y: 2 -> 3 (from `b c` to `b c y`) + + for (int i = 0; i < last_n_repeat - 1; ++i) { + int repeat_len = smpl->dry_repeat_count[i]; + if (repeat_len >= smpl->dry_allowed_length) { + // This token ends a repeat, so the next token would continue one. + // By convention, the value of `repeat_len` only includes the tokens currently + // in the context, not the new token that would be added. + llama_token token = smpl->last_tokens.rat(last_n_repeat - 2 - i); + // Track the maximum sequence ending in this token. + const auto& it = smpl->dry_max_token_repeat.find(token); + if (it == smpl->dry_max_token_repeat.end() || it->second < repeat_len) { + smpl->dry_max_token_repeat[token] = repeat_len; + } + } + } + + // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens. + + // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`. + // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()` + const float FLOAT_MAX_LOG = 88.7228391f; + int max_exponent = 0; + if (smpl->dry_base > 1.000001f) { + max_exponent = FLOAT_MAX_LOG / std::log(smpl->dry_base); + } + + for (size_t i = 0; i < cur_p->size; ++i) { + const auto& af_kvp = smpl->dry_max_token_repeat.find(cur_p->data[i].id); + if (af_kvp != smpl->dry_max_token_repeat.end()) { + // Check all sequence breakers starting with this token + auto range = smpl->dry_processed_breakers.equal_range(cur_p->data[i].id); + bool is_single_token_breaker = false; + + for (auto it = range.first; it != range.second; ++it) { + if (it->second.empty()) { + is_single_token_breaker = true; + break; + } + } + + // Apply penalty only if it's not a single-token sequence breaker + if (!is_single_token_breaker) { + int repeat_exp = af_kvp->second - smpl->dry_allowed_length; + if (max_exponent > 0 && repeat_exp > max_exponent) { + repeat_exp = max_exponent; + } + float penalty = smpl->dry_multiplier * std::pow(smpl->dry_base, repeat_exp); + cur_p->data[i].logit -= penalty; + } + } + } + + cur_p->sorted = false; +} + + + +struct llama_sampler_dry* llama_sampler_init_dry_impl(const struct llama_vocab& vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); + std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; + const int MAX_CHAR_LEN = 40; + const int MAX_SEQ_LEN = 20; + + const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { + // Process sequence breakers + for (size_t i = 0; i < num_breakers; ++i) { + if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) { + LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i); + continue; + } + + std::string sequence_break(seq_breakers[i]); + if (sequence_break.empty()) { + LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n"); + continue; + } + + if (sequence_break.size() > MAX_CHAR_LEN) { + LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN); + sequence_break.resize(MAX_CHAR_LEN); + } + + get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); + } + } + + return new llama_sampler_dry { + /* .total_context_size = */ context_size, + /* .dry_multiplier = */ dry_multiplier, + /* .dry_base = */ dry_base, + /* .dry_allowed_length = */ dry_allowed_length, + /* .dry_penalty_last_n = */ dry_penalty_last_n, + /* .dry_processed_breakers = */ std::move(processed_breakers), + /* .dry_repeat_count = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{}, + /* .dry_max_token_repeat = */ {}, + /* .last_tokens = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0), + }; +} + + diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 69d92a3a..855278e2 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,7 +1,7 @@ #pragma once #include "llama-impl.h" - +#include <unordered_map> struct llama_sampling { llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} @@ -35,6 +35,34 @@ void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_ void llama_sample_xtc_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float probability, float threshold, size_t min_keep); void llama_sample_top_n_sigma_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float top_n_sigma); +struct llama_sampler_dry { + int32_t total_context_size; + + const float dry_multiplier; + const float dry_base; + const int32_t dry_allowed_length; + const int32_t dry_penalty_last_n; + + std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers; + std::vector<int> dry_repeat_count; + std::unordered_map<llama_token, int> dry_max_token_repeat; + ring_buffer<llama_token> last_tokens; +}; + +struct llama_sampler_dry * llama_sampler_init_dry_impl( + const struct llama_vocab & vocab, + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const char ** seq_breakers, + size_t num_breakers); + +void llama_sampler_dry_apply(struct llama_sampler_dry* smpl, llama_token_data_array* cur_p); + + + void llama_sample_repetition_penalties_impl( struct llama_sampling * smpl, llama_token_data_array * candidates, @@ -56,3 +84,5 @@ llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, ll llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); + + diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 09399417..abf48824 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -75,6 +75,9 @@ struct naive_trie { llama_token value; }; +uint32_t llama_vocab::n_tokens() const { + return (uint32_t)id_to_token.size(); +} // // impl // @@ -1741,3 +1744,19 @@ int32_t llama_detokenize_impl( return total <= text_len_max ? total : -total; } + +std::string llama_detokenize(const struct llama_vocab& vocab, const std::vector<llama_token>& tokens, bool special) { + std::string text; + text.resize(std::max(text.capacity(), tokens.size())); + int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + if (n_chars < 0) { + text.resize(-n_chars); + n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); + GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization + } + + text.resize(n_chars); + + // NOTE: the original tokenizer decodes bytes after collecting the pieces. + return text; +} diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 7adfc16d..a461eca0 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -23,6 +23,8 @@ struct llama_vocab { int max_token_len = 0; // used for optimizing longest token search + uint32_t n_tokens() const; + std::unordered_map<token, id> token_to_id; std::vector<token_data> id_to_token; @@ -130,3 +132,8 @@ int32_t llama_detokenize_impl( int32_t text_len_max, bool remove_special, bool unparse_special); + +std::string llama_detokenize( + const struct llama_vocab& vocab, + const std::vector<llama_token>& tokens, + bool special); diff --git a/src/llama.cpp b/src/llama.cpp index af8ef9be..c0f147b9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20849,6 +20849,10 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } +const struct llama_vocab* llama_get_model_vocab(const struct llama_model* model) { + return &model->vocab; +} + enum llama_rope_type llama_rope_type(const struct llama_model * model) { switch (model->arch) { // these models do not use RoPE @@ -23280,6 +23284,11 @@ void llama_sample_top_n_sigma(struct llama_context * ctx, llama_token_data_array llama_sample_top_n_sigma_impl(ctx ? &ctx->sampling : nullptr, candidates_p, top_n_sigma); } + +void llama_sample_dry(struct llama_context* ctx, struct llama_sampler_dry* smpl, llama_token_data_array* candidates_p) { + llama_sampler_dry_apply(smpl, candidates_p); +} + void llama_sample_repetition_penalties( struct llama_context * ctx, llama_token_data_array * candidates, @@ -23327,6 +23336,42 @@ int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, return 0; } +struct llama_sampler_dry * llama_sampler_init_dry(const struct llama_vocab* vocab, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { + return llama_sampler_init_dry_impl(*vocab, vocab->n_tokens(), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers); +} + +void llama_sampler_dry_reset(struct llama_sampler_dry* smpl) { + smpl->last_tokens.clear(); + smpl->dry_repeat_count.clear(); + smpl->dry_max_token_repeat.clear(); +} + +void llama_sampler_dry_free(struct llama_sampler_dry* smpl) { + delete smpl; +} + +struct llama_sampler_dry* llama_sampler_dry_clone(struct llama_sampler_dry* smpl) { + // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying + auto* result = llama_sampler_init_dry(nullptr, smpl->dry_multiplier, smpl->dry_base, smpl->dry_allowed_length, smpl->dry_penalty_last_n, NULL, 0); + // Copy the state, including the processed breakers + { + auto* result_ctx = smpl; + result_ctx->dry_processed_breakers = smpl->dry_processed_breakers; + result_ctx->dry_repeat_count = smpl->dry_repeat_count; + result_ctx->dry_max_token_repeat = smpl->dry_max_token_repeat; + result_ctx->last_tokens = smpl->last_tokens; + } + + return result; +} + +void llama_sampler_dry_accept(struct llama_sampler_dry* smpl, llama_token token) { + if (smpl->dry_multiplier == 0.0f || smpl->dry_base < 1.0f || smpl->dry_penalty_last_n == 0) { + return; + } + smpl->last_tokens.push_back(token); +} + int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int split_no, int split_count) { std::string str_split_path(split_path); char postfix[32]; |