diff options
Diffstat (limited to 'common/sampling.cpp')
-rw-r--r-- | common/sampling.cpp | 69 |
1 files changed, 50 insertions, 19 deletions
diff --git a/common/sampling.cpp b/common/sampling.cpp index 4db12ee1..4b983e5f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,8 +1,9 @@ #define LLAMA_API_INTERNAL #include "sampling.h" +#include "llama-vocab.h" #include <random> -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { +struct llama_sampling_context * llama_sampling_init(const struct llama_vocab* vocab, const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); result->params = params; @@ -36,13 +37,32 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ } result->grammar = grammar; } - result->prev.resize(params.n_prev); result->n_valid = 0; + // init DRY + for (const auto& cnstr : params.samplers_sequence) + { + switch (cnstr) + { + case llama_sampler_type::DRY: + { + std::vector<const char*> c_breakers; + c_breakers.reserve(params.dry_sequence_breakers.size()); + for (const auto& str : params.dry_sequence_breakers) + { + c_breakers.push_back(str.c_str()); + } + result->smpl=llama_sampler_init_dry(vocab, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()); + + break; + } + default: + break; + } + } llama_sampling_set_rng_seed(result, params.seed); - return result; } @@ -50,7 +70,8 @@ void llama_sampling_free(struct llama_sampling_context * ctx) { if (ctx->grammar != NULL) { llama_grammar_free(ctx->grammar); } - + if (ctx->smpl !=NULL) + llama_sampler_dry_free(ctx->smpl); delete ctx; } @@ -75,6 +96,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); ctx->n_valid = 0; + llama_sampler_dry_reset(ctx->smpl); } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -95,6 +117,7 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds } dst->prev = src->prev; + dst->smpl = llama_sampler_dry_clone(src->smpl); } llama_token llama_sampling_last(llama_sampling_context * ctx) { @@ -149,6 +172,7 @@ std::string llama_sampling_order_print(const llama_sampling_params & params) { std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { switch (sampler_type) { + case llama_sampler_type::DRY: return "dry"; case llama_sampler_type::TOP_K: return "top_k"; case llama_sampler_type::TFS_Z: return "tfs_z"; case llama_sampler_type::TYPICAL_P: return "typical_p"; @@ -163,6 +187,7 @@ std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names) { std::unordered_map<std::string, llama_sampler_type> sampler_canonical_name_map { + {"dry", llama_sampler_type::DRY}, {"top_k", llama_sampler_type::TOP_K}, {"top_p", llama_sampler_type::TOP_P}, {"typical_p", llama_sampler_type::TYPICAL_P}, @@ -176,6 +201,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto // since samplers names are written multiple ways // make it ready for both system names and input names std::unordered_map<std::string, llama_sampler_type> sampler_alt_name_map { + {"dry", llama_sampler_type::DRY}, {"top-k", llama_sampler_type::TOP_K}, {"top-p", llama_sampler_type::TOP_P}, {"nucleus", llama_sampler_type::TOP_P}, @@ -215,6 +241,7 @@ std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vecto std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string) { std::unordered_map<char, llama_sampler_type> sampler_name_map { + {'d', llama_sampler_type::DRY}, {'k', llama_sampler_type::TOP_K}, {'p', llama_sampler_type::TOP_P}, {'y', llama_sampler_type::TYPICAL_P}, @@ -238,25 +265,28 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin // no reasons to expose this function in header static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, - llama_token_data_array & cur_p, - size_t min_keep) { - const float temp = params.temp; - const float dynatemp_range = params.dynatemp_range; + struct llama_context* ctx_main, + const llama_sampling_params& params, + llama_sampling_context * ctx_sampling, + llama_token_data_array& cur_p, + size_t min_keep) { + const float temp = params.temp; + const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const float xtc_probability = params.xtc_probability; - const float xtc_threshold = params.xtc_threshold; - const float top_n_sigma = params.top_n_sigma; + const int32_t top_k = params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const float xtc_probability = params.xtc_probability; + const float xtc_threshold = params.xtc_threshold; + const float top_n_sigma = params.top_n_sigma; + const std::vector<llama_sampler_type> & samplers_sequence = params.samplers_sequence; for (auto sampler_type : samplers_sequence) { switch (sampler_type) { + case llama_sampler_type::DRY : llama_sample_dry (ctx_main, ctx_sampling->smpl, &cur_p); break; case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; case llama_sampler_type::TYPICAL_P : llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; @@ -317,7 +347,7 @@ static llama_token llama_sampling_sample_impl( // temperature sampling size_t min_keep = std::max(1, params.min_keep); - sampler_queue(ctx_main, params, cur_p, min_keep); + sampler_queue(ctx_main, params,ctx_sampling, cur_p, min_keep); id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); @@ -472,4 +502,5 @@ void llama_sampling_accept( if (ctx_sampling->grammar != NULL && apply_grammar) { llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); } + llama_sampler_dry_accept(ctx_sampling->smpl, id); } |