| author | Georgi Gerganov <ggerganov@gmail.com> | 2023-10-20 21:07:23 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-10-20 21:07:23 +0300 |
| commit | d1031cf49c3b958b915fd558e23453471c29ac33 (patch) | |
| tree | 14fa2bc6d54d5e27bd1e8bfd6fa4dbf894dbe6b9 /common | |
| parent | 8cf19d60dc93809db8e51fedc811595eed9134c5 (diff) | |
sampling : refactor init to use llama_sampling_params (#3696)
* sampling : refactor init to use llama_sampling_params
* llama : combine repetition, frequency and presence penalties in 1 call
* examples : remove embd-input and gptneox-wip
* sampling : rename penalty params + reduce size of "prev" vector
* sampling : add llama_sampling_print helper
* sampling : hide prev behind API and apply #3661
ggml-ci
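In practical terms, the refactor changes the caller side of the `common` sampling API: `llama_sampling_init()` now receives the `llama_sampling_params` sub-struct directly instead of the whole `gpt_params`, and `llama_sampling_accept()` takes a new `apply_grammar` flag. The sketch below is illustrative only, not code from this commit; the helper name `generate_n_tokens` and the surrounding setup are made up.

```cpp
// Hypothetical caller-side sketch of the refactored common/sampling API.
// Assumes an already-loaded llama_context; generate_n_tokens is a made-up name.
#include "common.h"
#include "sampling.h"

void generate_n_tokens(llama_context * ctx, const gpt_params & params, int n) {
    // before this commit: llama_sampling_init(params) took the full gpt_params;
    // after: pass the sampling sub-struct directly
    llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);

    for (int i = 0; i < n; i++) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /*ctx_cfg=*/ nullptr, /*idx=*/ 0);

        // new apply_grammar flag: pass false when an accepted token must not
        // advance the grammar state (e.g. draft tokens in speculative decoding)
        llama_sampling_accept(ctx_sampling, ctx, id, /*apply_grammar=*/ true);

        // ... evaluate `id` with llama_decode() and break on EOS ...
    }

    llama_sampling_free(ctx_sampling);
}
```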
Diffstat (limited to 'common')
| mode | file | changes |
|---|---|---|
| -rw-r--r-- | common/common.cpp | 69 |
| -rw-r--r-- | common/common.h | 3 |
| -rw-r--r-- | common/sampling.cpp | 73 |
| -rw-r--r-- | common/sampling.h | 32 |

4 files changed, 108 insertions(+), 69 deletions(-)
```diff
diff --git a/common/common.cpp b/common/common.cpp
index ce14d66b..2ef902bd 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -107,7 +107,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
     std::string arg;
     gpt_params default_params;
     const std::string arg_prefix = "--";
-    llama_sampling_params & sparams = params.sampling_params;
+    llama_sampling_params & sparams = params.sparams;
 
     for (int i = 1; i < argc; i++) {
         arg = argv[i];
@@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_last_n = std::stoi(argv[i]);
+            sparams.penalty_last_n = std::stoi(argv[i]);
+            sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
         } else if (arg == "--repeat-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_penalty = std::stof(argv[i]);
+            sparams.penalty_repeat = std::stof(argv[i]);
         } else if (arg == "--frequency-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.frequency_penalty = std::stof(argv[i]);
+            sparams.penalty_freq = std::stof(argv[i]);
         } else if (arg == "--presence-penalty") {
            if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.presence_penalty = std::stof(argv[i]);
+            sparams.penalty_present = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -572,7 +573,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            params.grammar = argv[i];
+            sparams.grammar = argv[i];
         } else if (arg == "--grammar-file") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -587,7 +588,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             std::copy(
                 std::istreambuf_iterator<char>(file),
                 std::istreambuf_iterator<char>(),
-                std::back_inserter(params.grammar)
+                std::back_inserter(sparams.grammar)
             );
 #ifndef LOG_DISABLE_LOGS
     // Parse args for logging parameters
@@ -640,7 +641,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
 }
 
 void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
-    const llama_sampling_params & sparams = params.sampling_params;
+    const llama_sampling_params & sparams = params.sparams;
 
     printf("usage: %s [options]\n", argv[0]);
     printf("\n");
@@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
-    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
-    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
-    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
-    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
+    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
+    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
+    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
+    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
     printf("  --mirostat N          use Mirostat sampling.\n");
     printf("                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     printf("                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@@ -878,7 +879,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     }
 
     if (params.ignore_eos) {
-        params.sampling_params.logit_bias[llama_token_eos(lctx)] = -INFINITY;
+        params.sparams.logit_bias[llama_token_eos(lctx)] = -INFINITY;
     }
 
     {
@@ -1123,28 +1124,28 @@ std::string get_sortable_timestamp() {
 
 void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
                                const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc) {
-    const llama_sampling_params & sparams = params.sampling_params;
+    const llama_sampling_params & sparams = params.sparams;
 
     fprintf(stream, "build_commit: %s\n", BUILD_COMMIT);
     fprintf(stream, "build_number: %d\n", BUILD_NUMBER);
-    fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
-    fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
-    fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
-    fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
+    fprintf(stream, "cpu_has_arm_fma: %s\n",     ggml_cpu_has_arm_fma()     ? "true" : "false");
+    fprintf(stream, "cpu_has_avx: %s\n",         ggml_cpu_has_avx()         ? "true" : "false");
+    fprintf(stream, "cpu_has_avx2: %s\n",        ggml_cpu_has_avx2()        ? "true" : "false");
+    fprintf(stream, "cpu_has_avx512: %s\n",      ggml_cpu_has_avx512()      ? "true" : "false");
     fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
     fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
-    fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
-    fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
-    fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
-    fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
-    fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
-    fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
-    fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
-    fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
-    fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
-    fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
-    fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
-    fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
+    fprintf(stream, "cpu_has_blas: %s\n",        ggml_cpu_has_blas()        ? "true" : "false");
+    fprintf(stream, "cpu_has_cublas: %s\n",      ggml_cpu_has_cublas()      ? "true" : "false");
+    fprintf(stream, "cpu_has_clblast: %s\n",     ggml_cpu_has_clblast()     ? "true" : "false");
+    fprintf(stream, "cpu_has_fma: %s\n",         ggml_cpu_has_fma()         ? "true" : "false");
+    fprintf(stream, "cpu_has_gpublas: %s\n",     ggml_cpu_has_gpublas()     ? "true" : "false");
+    fprintf(stream, "cpu_has_neon: %s\n",        ggml_cpu_has_neon()        ? "true" : "false");
+    fprintf(stream, "cpu_has_f16c: %s\n",        ggml_cpu_has_f16c()        ? "true" : "false");
+    fprintf(stream, "cpu_has_fp16_va: %s\n",     ggml_cpu_has_fp16_va()     ? "true" : "false");
+    fprintf(stream, "cpu_has_wasm_simd: %s\n",   ggml_cpu_has_wasm_simd()   ? "true" : "false");
+    fprintf(stream, "cpu_has_blas: %s\n",        ggml_cpu_has_blas()        ? "true" : "false");
+    fprintf(stream, "cpu_has_sse3: %s\n",        ggml_cpu_has_sse3()        ? "true" : "false");
+    fprintf(stream, "cpu_has_vsx: %s\n",         ggml_cpu_has_vsx()         ? "true" : "false");
 
 #ifdef NDEBUG
     fprintf(stream, "debug: false\n");
@@ -1178,8 +1179,8 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
-    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
-    dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
+    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
+    dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
     fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
@@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
     fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
     fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
-    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
+    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
     dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
     fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
     fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
     fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
     dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
     fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
-    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
+    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
     fprintf(stream, "reverse_prompt:\n");
     for (std::string ap : params.antiprompt) {
diff --git a/common/common.h b/common/common.h
index 65d3d20c..84523a4f 100644
--- a/common/common.h
+++ b/common/common.h
@@ -56,7 +56,7 @@ struct gpt_params {
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
 
     // // sampling parameters
-    struct llama_sampling_params sampling_params;
+    struct llama_sampling_params sparams;
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
     std::string model_draft = ""; // draft model for speculative decoding
@@ -66,7 +66,6 @@ struct gpt_params {
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
     std::string input_prefix = ""; // string to prefix user inputs with
     std::string input_suffix = ""; // string to suffix user inputs with
-    std::string grammar = ""; // optional BNF-like grammar to constrain sampling
     std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
 
     std::string logdir = ""; // directory in which to save YAML log files
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 0b246658..6f0af3c4 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,9 +1,9 @@
 #include "sampling.h"
 
-struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
     struct llama_sampling_context * result = new llama_sampling_context();
 
-    result->params = params.sampling_params;
+    result->params = params;
     result->grammar = nullptr;
 
     // if there is a grammar, parse it
@@ -23,7 +23,7 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_params & pa
             grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
     }
 
-    result->prev.resize(params.n_ctx);
+    result->prev.resize(params.n_prev);
 
     return result;
 }
@@ -66,25 +66,56 @@ void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * ds
     dst->prev = src->prev;
 }
 
+llama_token llama_sampling_last(llama_sampling_context * ctx) {
+    return ctx->prev.back();
+}
+
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
+    const int size = ctx_sampling->prev.size();
+
+    n = std::min(n, size);
+
+    std::string result;
+
+    for (int i = size - n; i < size; i++) {
+        result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]);
+    }
+
+    return result;
+}
+
+std::string llama_sampling_print(const llama_sampling_params & params) {
+    char result[1024];
+
+    snprintf(result, sizeof(result),
+            "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n"
+            "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, typical_p = %.3f, temp = %.3f\n"
+            "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f",
+            params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present,
+            params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp,
+            params.mirostat, params.mirostat_eta, params.mirostat_tau);
+
+    return std::string(result);
+}
+
 llama_token llama_sampling_sample(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
         struct llama_context * ctx_cfg,
         const int idx) {
-    const int n_ctx   = llama_n_ctx(ctx_main);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
-
     const llama_sampling_params & params = ctx_sampling->params;
 
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
     const float   temp            = params.temp;
     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
     const float   top_p           = params.top_p;
     const float   tfs_z           = params.tfs_z;
     const float   typical_p       = params.typical_p;
-    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
-    const float   repeat_penalty  = params.repeat_penalty;
-    const float   alpha_presence  = params.presence_penalty;
-    const float   alpha_frequency = params.frequency_penalty;
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
@@ -97,7 +128,7 @@ llama_token llama_sampling_sample(
 
     float * logits = llama_get_logits_ith(ctx_main, idx);
 
-    // Apply params.logit_bias map
+    // apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
     }
@@ -117,14 +148,10 @@ llama_token llama_sampling_sample(
     // apply penalties
     if (!prev.empty()) {
         const float nl_logit = logits[llama_token_nl(ctx_main)];
-        const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
 
-        llama_sample_repetition_penalty(ctx_main, &cur_p,
-                prev.data() + prev.size() - last_n_repeat,
-                last_n_repeat, repeat_penalty);
-        llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
-                prev.data() + prev.size() - last_n_repeat,
-                last_n_repeat, alpha_frequency, alpha_presence);
+        llama_sample_repetition_penalties(ctx_main, &cur_p,
+                prev.data() + prev.size() - penalty_last_n,
+                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
@@ -141,7 +168,7 @@ llama_token llama_sampling_sample(
     }
 
     if (temp <= 0) {
-        // Greedy sampling
+        // greedy sampling
         id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
@@ -152,8 +179,9 @@ llama_token llama_sampling_sample(
             llama_sample_temp(ctx_main, &cur_p, temp);
             id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
-            // Temperature sampling
+            // temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
+
             llama_sample_top_k    (ctx_main, &cur_p, top_k, min_keep);
             llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
             llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
@@ -183,11 +211,12 @@ llama_token llama_sampling_sample(
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        llama_token id) {
+        llama_token id,
+        bool apply_grammar) {
     ctx_sampling->prev.erase(ctx_sampling->prev.begin());
     ctx_sampling->prev.push_back(id);
 
-    if (ctx_sampling->grammar != NULL) {
+    if (ctx_sampling->grammar != NULL && apply_grammar) {
         llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
     }
 }
diff --git a/common/sampling.h b/common/sampling.h
index 50afcbc1..62ea6d4c 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -10,30 +10,30 @@
 
 // sampling parameters
 typedef struct llama_sampling_params {
+    int32_t n_prev            = 64;    // number of previous tokens to remember
+    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
     float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typical_p         = 1.00f; // 1.0 = disabled
     float   temp              = 0.80f; // 1.0 = disabled
-    float   repeat_penalty    = 1.10f; // 1.0 = disabled
-    int32_t repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   frequency_penalty = 0.00f; // 0.0 = disabled
-    float   presence_penalty  = 0.00f; // 0.0 = disabled
+    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat    = 1.10f; // 1.0 = disabled
+    float   penalty_freq      = 0.00f; // 0.0 = disabled
+    float   penalty_present   = 0.00f; // 0.0 = disabled
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
-    bool    penalize_nl       = true;  // consider newlines as a repeatable token
-    int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
+    bool    penalize_nl       = true;  // consider newlines as a repeatable token
+
+    std::string grammar;  // optional BNF-like grammar to constrain sampling
 
     // Classifier-Free Guidance
     // https://arxiv.org/abs/2306.17806
-    std::string cfg_negative_prompt; // string to help guidance
-    float       cfg_scale     = 1.f; // How strong is guidance
+    std::string cfg_negative_prompt; // string to help guidance
+    float       cfg_scale     = 1.f; // how strong is guidance
 
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
 } llama_sampling_params;
 
 // general sampler context
@@ -58,7 +58,7 @@ struct llama_sampling_context {
 
 #include "common.h"
 
 // Create a new sampling context instance.
-struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params);
 
 void llama_sampling_free(struct llama_sampling_context * ctx);
@@ -70,6 +70,15 @@ void llama_sampling_reset(llama_sampling_context * ctx);
 // Copy the sampler context
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);
 
+// Get the last sampled token
+llama_token llama_sampling_last(llama_sampling_context * ctx);
+
+// Get a string representation of the last sampled tokens
+std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);
+
+// Print sampling parameters into a string
+std::string llama_sampling_print(const llama_sampling_params & params);
+
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
@@ -96,4 +105,5 @@ llama_token llama_sampling_sample(
 void llama_sampling_accept(
         struct llama_sampling_context * ctx_sampling,
         struct llama_context * ctx_main,
-        llama_token id);
+        llama_token id,
+        bool apply_grammar);
```
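The main functional consolidation above is that `llama_sample_repetition_penalty()` and `llama_sample_frequency_and_presence_penalties()` are folded into a single `llama_sample_repetition_penalties()` call over the last `penalty_last_n` tokens. Below is a simplified sketch of the math the combined call applies, assuming raw logits in a plain vector; the actual library operates on `llama_token_data` candidate arrays, so this is illustrative, not the implementation verbatim.

```cpp
// Simplified sketch of combined repetition/frequency/presence penalties.
#include <unordered_map>
#include <vector>

void apply_penalties(std::vector<float> & logits,
                     const std::vector<int> & last_tokens, // the penalty window
                     float penalty_repeat,   // 1.0 = disabled
                     float penalty_freq,     // 0.0 = disabled
                     float penalty_present)  // 0.0 = disabled
{
    // count occurrences of each token in the penalty window
    std::unordered_map<int, int> counts;
    for (int tok : last_tokens) {
        counts[tok]++;
    }

    for (const auto & [tok, count] : counts) {
        float & logit = logits[tok];

        // repetition penalty: divide positive logits, multiply negative ones,
        // so a seen token always becomes less likely
        if (logit > 0) {
            logit /= penalty_repeat;
        } else {
            logit *= penalty_repeat;
        }

        // OpenAI-style penalties: frequency scales with the count,
        // presence is a flat offset for any token that appeared at all
        logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
    }
}
```

Doing all three adjustments in one pass also means the caller maintains a single token window (`prev`, now capped by `n_prev`) instead of recomputing `last_n_repeat` separately for each penalty call.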