From 70c29da118cdb02bfcbd0376c32b5b2236e48e48 Mon Sep 17 00:00:00 2001
From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>
Date: Wed, 11 Oct 2023 13:35:46 -0600
Subject: common : fix mirostat state when using multiple sequences (#3543)

* Fix mirostat state when using multiple sequences

* Fix mirostat by completely refactoring sampling!

* Try to fix zig build.

* Export function to fetch/create default sampler states

Code formatting cleanups and add some comments

Silence a warning about id not being used when logging is disabled

* Apply some renaming suggestions.

Fix comments that were out of sync with the pull.

* Use more consistent naming convention for sampling contexts
---
 examples/embd-input/embd-input-lib.cpp | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

(limited to 'examples/embd-input/embd-input-lib.cpp')

diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 99e6bdad..87a5a1c2 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){
 llama_token sampling_id(struct MyModel* mymodel) {
     llama_context* ctx = mymodel->ctx;
     gpt_params params = mymodel->params;
+    llama_sampling_params & sparams = params.sampling_params;
     // int n_ctx = llama_n_ctx(ctx);
 
     // out of user input, sample next token
-    const float temp = params.temp;
-    const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
-    const float top_p = params.top_p;
-    const float tfs_z = params.tfs_z;
-    const float typical_p = params.typical_p;
+    const float temp = sparams.temp;
+    const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
+    const float top_p = sparams.top_p;
+    const float tfs_z = sparams.tfs_z;
+    const float typical_p = sparams.typical_p;
     // const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
     // const float repeat_penalty = params.repeat_penalty;
     // const float alpha_presence = params.presence_penalty;
     // const float alpha_frequency = params.frequency_penalty;
-    const int mirostat = params.mirostat;
-    const float mirostat_tau = params.mirostat_tau;
-    const float mirostat_eta = params.mirostat_eta;
+    const int mirostat = sparams.mirostat;
+    const float mirostat_tau = sparams.mirostat_tau;
+    const float mirostat_eta = sparams.mirostat_eta;
     // const bool penalize_nl = params.penalize_nl;
 
     llama_token id = 0;
@@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
         auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
         // Apply params.logit_bias map
-        for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+        for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
            logits[it->first] += it->second;
         }
 
--
cgit v1.2.3
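
Editor's note: the hunks above only show the example switching from params.* to sparams.* (the new llama_sampling_params member); the actual per-sequence fix lives in the refactored common sampling code, which is not part of this diff. The sketch below is an illustrative, hypothetical C++ snippet (the names seq_sampler_state, sampler_registry and get_sampler_state are invented, not llama.cpp API) showing the general idea the commit message describes: fetching or creating a default sampler state per sequence so that mirostat's running "mu" value is no longer shared across sequences.

    // Minimal sketch, not the real llama.cpp implementation.
    // Mirostat maintains a running surprise target "mu" that is updated after
    // every sampled token; if one mu is shared by all sequences, tokens drawn
    // for one sequence corrupt the sampler state of every other sequence.
    #include <cstdint>
    #include <unordered_map>

    struct seq_sampler_state {
        float mirostat_mu; // per-sequence mirostat state
    };

    struct sampler_registry {
        float default_tau = 5.0f; // assumed default target cross-entropy
        std::unordered_map<int32_t, seq_sampler_state> states; // keyed by sequence id

        // Fetch the state for a sequence, creating it with defaults on first use.
        seq_sampler_state & get_sampler_state(int32_t seq_id) {
            auto it = states.find(seq_id);
            if (it == states.end()) {
                // Common mirostat initialization: mu starts at 2 * tau.
                it = states.emplace(seq_id, seq_sampler_state{2.0f * default_tau}).first;
            }
            return it->second;
        }
    };

With this layout, each decoding sequence looks up its own state before sampling and writes the updated mu back afterwards, which is the behavior the "export function to fetch/create default sampler states" bullet in the commit message refers to.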