From 52c8bc3cf312e1caf02d37bfb9d9d865cbe33594 Mon Sep 17 00:00:00 2001 From: MaggotHATE Date: Tue, 5 Dec 2023 15:05:51 +0500 Subject: sampling : custom samplers order (#4285) * Samplers sequence order w parameter * Cleaned commented code * Fixed formatting * Rewrote with unordered_map * Revert and rewrite, too many problems and safeguards would be needed * Fixed code style * Code style fixes according to review * More readable samplers input string, fixed help * Style fix in sampler_queue * Formatting fixes * Fixing whitespaces --- common/sampling.cpp | 60 +++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 11 deletions(-) (limited to 'common/sampling.cpp') diff --git a/common/sampling.cpp b/common/sampling.cpp index 1317024c..b6bb886c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -99,6 +99,54 @@ std::string llama_sampling_print(const llama_sampling_params & params) { return std::string(result); } +std::string llama_sampling_order_print(const llama_sampling_params & params) { + std::string result = "CFG -> Penalties "; + if (params.mirostat == 0) { + for (auto s : params.samplers_sequence) { + switch (s) { + case 'k': result += "-> top_k "; break; + case 'f': result += "-> tfs_z "; break; + case 'y': result += "-> typical_p "; break; + case 'p': result += "-> top_p "; break; + case 'm': result += "-> min_p "; break; + case 't': result += "-> temp "; break; + default : break; + } + } + } else result += "-> mirostat "; + + return result; +} + +// no reasons to expose this function in header +void sampler_queue( + struct llama_context * ctx_main, + const llama_sampling_params & params, + llama_token_data_array & cur_p, + size_t & min_keep) { + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + + const float temp = params.temp; + const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; + const float top_p = params.top_p; + const float min_p = params.min_p; + const float tfs_z = params.tfs_z; + const float typical_p = params.typical_p; + const std::string & samplers_sequence = params.samplers_sequence; + + for (auto s : samplers_sequence) { + switch (s){ + case 'k': llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; + case 'f': llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; + case 'y': llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; + case 'p': llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; + case 'm': llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case 't': llama_sample_temp (ctx_main, &cur_p, temp); break; + default : break; + } + } +} + llama_token llama_sampling_sample( struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_main, @@ -109,11 +157,6 @@ llama_token llama_sampling_sample( const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); const float temp = params.temp; - const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; const float penalty_repeat = params.penalty_repeat; const float penalty_freq = params.penalty_freq; @@ -188,12 +231,7 @@ llama_token llama_sampling_sample( // temperature sampling size_t min_keep = std::max(1, params.n_probs); - llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); - llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); - llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); - llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); - llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); - llama_sample_temp (ctx_main, &cur_p, temp); + sampler_queue(ctx_main, params, cur_p, min_keep); id = llama_sample_token(ctx_main, &cur_p); -- cgit v1.2.3