path: root/examples/main/main.cpp
author    Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-10-11 13:35:46 -0600
committer GitHub <noreply@github.com>  2023-10-11 22:35:46 +0300
commit    70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch)
tree      9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d /examples/main/main.cpp
parent    8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff)
common : fix mirostat state when using multiple sequences (#3543)
* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
  Code formatting cleanups and add some comments
  Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions.
  Fix comments that were out of sync with the pull.
* Use more consistent naming convention for sampling contexts
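
For orientation, here is a minimal sketch (not part of the commit) of how the refactored calls fit together in main.cpp after this change. It uses only names that actually appear in the hunks below (llama_sampling_params, llama_sampling_context, llama_sampling_context_init, llama_sampling_sample); the wrapper function, its headers, and its placement are illustrative assumptions.

    // Illustrative sketch only -- the wrapper and includes are assumptions;
    // the sampling calls are the ones introduced by this commit.
    #include "common.h" // assumed to provide gpt_params and the llama_sampling_* helpers
    #include "llama.h"

    #include <vector>

    static llama_token sample_one(llama_context * ctx, llama_context * ctx_guidance,
                                  llama_sampling_context & ctx_sampling,
                                  std::vector<llama_token> & last_tokens,
                                  std::vector<llama_token_data> & candidates) {
        // Sampler state (mirostat mu, grammar state, ...) now lives in ctx_sampling
        // rather than being reconstructed from gpt_params on every call.
        const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling,
                                                     last_tokens, candidates);

        // Keep the rolling token history up to date, as main.cpp does below.
        last_tokens.erase(last_tokens.begin());
        last_tokens.push_back(id);
        return id;
    }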
Diffstat (limited to 'examples/main/main.cpp')
-rw-r--r--  examples/main/main.cpp  18
1 file changed, 10 insertions, 8 deletions
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 775a5a20..b39a67d9 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -109,6 +109,7 @@ int main(int argc, char ** argv) {
if (!gpt_params_parse(argc, argv, params)) {
return 1;
}
+ llama_sampling_params & sparams = params.sampling_params;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("main", "log"));
@@ -179,7 +180,7 @@ int main(int argc, char ** argv) {
// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
std::tie(model, ctx) = llama_init_from_gpt_params(params);
- if (params.cfg_scale > 1.f) {
+ if (sparams.cfg_scale > 1.f) {
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
ctx_guidance = llama_new_context_with_model(model, lparams);
}
@@ -257,9 +258,9 @@ int main(int argc, char ** argv) {
int guidance_offset = 0;
int original_prompt_len = 0;
if (ctx_guidance) {
- LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
+ LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
- guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+ guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
@@ -343,7 +344,7 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
LOG_TEE("\n");
- LOG_TEE("%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str());
+ LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
for (int i = 0; i < (int) guidance_inp.size(); i++) {
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
@@ -395,7 +396,7 @@ int main(int argc, char ** argv) {
}
}
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
- params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
+ sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
@@ -413,8 +414,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
{
- auto it = params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ auto it = sparams.logit_bias.find(llama_token_eos(ctx));
+ if (it != sparams.logit_bias.end() && it->second == -INFINITY) {
LOG_TEE("%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__);
}
}
@@ -469,6 +470,7 @@ int main(int argc, char ** argv) {
const int n_vocab = llama_n_vocab(model);
+ llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@@ -625,7 +627,7 @@ int main(int argc, char ** argv) {
LOG("saved session to %s\n", path_session.c_str());
}
- const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
+ const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling, last_tokens, candidates);
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
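
main.cpp itself drives a single sequence, so the fix named in the commit title is easier to see in a hypothetical multi-sequence caller: each sequence gets its own llama_sampling_context, so mirostat's internal state is no longer shared across sequences. The sketch below is an assumption built only from the two functions shown in the hunks above; n_seq and the per-sequence token histories are illustrative names, not part of this diff.

    // Hypothetical sketch, not part of this diff: one sampling context per sequence.
    std::vector<llama_sampling_context> ctx_sampling_per_seq;
    for (int s = 0; s < n_seq; ++s) {
        ctx_sampling_per_seq.push_back(llama_sampling_context_init(params, grammar));
    }

    // When sampling a token for sequence s, that sequence's own context (and thus
    // its own mirostat state) is used instead of a single shared one:
    const llama_token id = llama_sampling_sample(ctx, ctx_guidance, ctx_sampling_per_seq[s],
                                                 last_tokens_per_seq[s], candidates);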