path: root/examples/embd-input
author    Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-10-11 13:35:46 -0600
committer GitHub <noreply@github.com>  2023-10-11 22:35:46 +0300
commit    70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch)
tree      9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d /examples/embd-input
parent    8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff)
common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
  - Code formatting cleanups and add some comments
  - Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions. Fix comments that were out of sync with the pull.
* Use a more consistent naming convention for sampling contexts
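The core idea is to keep mirostat's running state per sequence instead of in one shared set of parameters. The snippet below is a minimal, standalone sketch of that approach; the sampling_ctx struct and get_sampling_ctx helper are hypothetical names used only for illustration, not the actual llama.cpp API, which the diff below only touches through llama_sampling_params.

#include <cstdio>
#include <unordered_map>

// Hypothetical per-sequence sampling context: mirostat's running "mu" value
// lives here instead of in shared sampling parameters.
struct sampling_ctx {
    float mirostat_mu = 10.0f; // mirostat convention: start at 2 * tau (tau = 5.0 assumed)
};

// Fetch the context for a sequence id, creating a default one on first use
// (the "fetch/create default sampler states" idea from the commit message).
static sampling_ctx & get_sampling_ctx(std::unordered_map<int, sampling_ctx> & states, int seq_id) {
    return states[seq_id]; // operator[] default-constructs the state on first access
}

int main() {
    std::unordered_map<int, sampling_ctx> states;

    // Each sequence updates only its own mu, so sequences no longer interfere.
    get_sampling_ctx(states, 0).mirostat_mu -= 0.5f;
    get_sampling_ctx(states, 1).mirostat_mu -= 2.0f;

    std::printf("seq 0 mu = %.2f\n", states[0].mirostat_mu); // 9.50
    std::printf("seq 1 mu = %.2f\n", states[1].mirostat_mu); // 8.00
    return 0;
}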
Diffstat (limited to 'examples/embd-input')
-rw-r--r--  examples/embd-input/embd-input-lib.cpp | 19
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp
index 99e6bdad..87a5a1c2 100644
--- a/examples/embd-input/embd-input-lib.cpp
+++ b/examples/embd-input/embd-input-lib.cpp
@@ -128,21 +128,22 @@ bool eval_string(struct MyModel * mymodel,const char* str){
llama_token sampling_id(struct MyModel* mymodel) {
llama_context* ctx = mymodel->ctx;
gpt_params params = mymodel->params;
+ llama_sampling_params & sparams = params.sampling_params;
// int n_ctx = llama_n_ctx(ctx);
// out of user input, sample next token
- const float temp = params.temp;
- const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params.top_k;
- const float top_p = params.top_p;
- const float tfs_z = params.tfs_z;
- const float typical_p = params.typical_p;
+ const float temp = sparams.temp;
+ const int32_t top_k = sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : sparams.top_k;
+ const float top_p = sparams.top_p;
+ const float tfs_z = sparams.tfs_z;
+ const float typical_p = sparams.typical_p;
// const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
// const float repeat_penalty = params.repeat_penalty;
// const float alpha_presence = params.presence_penalty;
// const float alpha_frequency = params.frequency_penalty;
- const int mirostat = params.mirostat;
- const float mirostat_tau = params.mirostat_tau;
- const float mirostat_eta = params.mirostat_eta;
+ const int mirostat = sparams.mirostat;
+ const float mirostat_tau = sparams.mirostat_tau;
+ const float mirostat_eta = sparams.mirostat_eta;
// const bool penalize_nl = params.penalize_nl;
llama_token id = 0;
@@ -151,7 +152,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
auto n_vocab = llama_n_vocab(llama_get_model(ctx));
// Apply params.logit_bias map
- for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
+ for (auto it = sparams.logit_bias.begin(); it != sparams.logit_bias.end(); it++) {
logits[it->first] += it->second;
}
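For reference, the logit-bias loop updated in the second hunk can be written as a small standalone function. This is only a sketch assuming the bias map is keyed by token id with a float bias (the concrete container type behind sparams.logit_bias may differ), and it uses a range-based for loop instead of explicit iterators.

#include <cstddef>
#include <unordered_map>
#include <vector>

using llama_token = int; // stand-in for the real typedef in llama.h

// Add each configured bias to the corresponding token's logit, as the loop
// over sparams.logit_bias does in sampling_id().
static void apply_logit_bias(std::vector<float> & logits,
                             const std::unordered_map<llama_token, float> & logit_bias) {
    for (const auto & [token, bias] : logit_bias) {
        if (token >= 0 && (size_t) token < logits.size()) {
            logits[token] += bias;
        }
    }
}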