author     Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com>  2023-10-11 13:35:46 -0600
committer  GitHub <noreply@github.com>  2023-10-11 22:35:46 +0300
commit     70c29da118cdb02bfcbd0376c32b5b2236e48e48 (patch)
tree       9ba08e6a18d60e24b580d58b57f9c2b7a8848f3d /examples/server/server.cpp
parent     8c70a5ff25964f0a81e20d142a2f5ac5baff22fc (diff)
common : fix mirostat state when using multiple sequences (#3543)

* Fix mirostat state when using multiple sequences
* Fix mirostat by completely refactoring sampling!
* Try to fix zig build.
* Export function to fetch/create default sampler states
  Code formatting cleanups and add some comments
  Silence a warning about id not being used when logging is disabled
* Apply some renaming suggestions.
  Fix comments that were out of sync with the pull.
* Use more consistent naming convention for sampling contexts
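
For context on why per-sequence sampling state matters here: mirostat keeps an adaptive running value (its mu, initialized to 2 * tau, with tau defaulting to 5.0) that must be tracked per sequence rather than in one shared place. The following is a minimal illustrative sketch of that idea only; the struct, field, and function names are assumptions for illustration and are not the actual layout in common/sampling.h:

    // Illustrative sketch only; the real state lives in llama_sampling_context (common/sampling.h).
    #include <map>
    #include <vector>

    struct seq_sampling_state {
        float mirostat_mu = 10.0f;        // mirostat's adaptive target, initialized to 2 * tau
        std::vector<int> recent_tokens;   // repetition-penalty window, also kept per sequence
    };

    // Keyed by sequence id, so one sequence's mirostat feedback no longer
    // bleeds into another sequence's sampling decisions.
    static std::map<int, seq_sampling_state> g_seq_states;

    static seq_sampling_state & state_for_seq(int seq_id) {
        return g_seq_states[seq_id];      // default-constructed on first use
    }

    int main() {
        state_for_seq(0).mirostat_mu = 7.5f;                   // updating sequence 0 ...
        return state_for_seq(1).mirostat_mu == 10.0f ? 0 : 1;  // ... leaves sequence 1 untouched
    }

In the actual change below, this role is played by the ctx_sampling member added to llama_server_context, which is rebuilt via llama_sampling_context_init whenever the grammar or the request's sampling parameters change.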
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp | 100
1 file changed, 54 insertions(+), 46 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 8c5318c6..58af78de 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -200,6 +200,7 @@ struct llama_server_context
llama_model *model = nullptr;
llama_context *ctx = nullptr;
gpt_params params;
+ llama_sampling_context ctx_sampling;
int n_ctx;
grammar_parser::parse_state parsed_grammar;
@@ -254,6 +255,7 @@ struct llama_server_context
if (grammar != nullptr) {
llama_grammar_free(grammar);
grammar = nullptr;
+ ctx_sampling = llama_sampling_context_init(params, NULL);
}
}
@@ -329,8 +331,8 @@ struct llama_server_context
grammar_parser::print_grammar(stderr, parsed_grammar);
{
- auto it = params.logit_bias.find(llama_token_eos(ctx));
- if (it != params.logit_bias.end() && it->second == -INFINITY) {
+ auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
+ if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
}
}
@@ -339,6 +341,7 @@ struct llama_server_context
grammar = llama_grammar_init(
grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}
+ ctx_sampling = llama_sampling_context_init(params, grammar);
return true;
}
@@ -550,12 +553,12 @@ struct llama_server_context
std::vector<llama_token_data> candidates;
candidates.reserve(llama_n_vocab(model));
- result.tok = llama_sample_token(ctx, NULL, grammar, params, last_n_tokens, candidates);
+ result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
- const int32_t n_probs = params.n_probs;
- if (params.temp <= 0 && n_probs > 0)
+ const int32_t n_probs = params.sampling_params.n_probs;
+ if (params.sampling_params.temp <= 0 && n_probs > 0)
{
// For llama_sample_token_greedy we need to sort candidates
llama_sample_softmax(ctx, &candidates_p);
@@ -630,7 +633,7 @@ struct llama_server_context
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_piece(ctx, token_with_probs.tok);
generated_text += token_text;
- if (params.n_probs > 0)
+ if (params.sampling_params.n_probs > 0)
{
generated_token_probs.push_back(token_with_probs);
}
@@ -1018,34 +1021,35 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
static json format_generation_settings(llama_server_context &llama)
{
- const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
- const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
+ const auto & sparams = llama.params.sampling_params;
+ const auto eos_bias = sparams.logit_bias.find(llama_token_eos(llama.ctx));
+ const bool ignore_eos = eos_bias != sparams.logit_bias.end() &&
eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{
{"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias},
{"seed", llama.params.seed},
- {"temp", llama.params.temp},
- {"top_k", llama.params.top_k},
- {"top_p", llama.params.top_p},
- {"tfs_z", llama.params.tfs_z},
- {"typical_p", llama.params.typical_p},
- {"repeat_last_n", llama.params.repeat_last_n},
- {"repeat_penalty", llama.params.repeat_penalty},
- {"presence_penalty", llama.params.presence_penalty},
- {"frequency_penalty", llama.params.frequency_penalty},
- {"mirostat", llama.params.mirostat},
- {"mirostat_tau", llama.params.mirostat_tau},
- {"mirostat_eta", llama.params.mirostat_eta},
- {"penalize_nl", llama.params.penalize_nl},
+ {"temp", sparams.temp},
+ {"top_k", sparams.top_k},
+ {"top_p", sparams.top_p},
+ {"tfs_z", sparams.tfs_z},
+ {"typical_p", sparams.typical_p},
+ {"repeat_last_n", sparams.repeat_last_n},
+ {"repeat_penalty", sparams.repeat_penalty},
+ {"presence_penalty", sparams.presence_penalty},
+ {"frequency_penalty", sparams.frequency_penalty},
+ {"mirostat", sparams.mirostat},
+ {"mirostat_tau", sparams.mirostat_tau},
+ {"mirostat_eta", sparams.mirostat_eta},
+ {"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos},
{"stream", llama.stream},
- {"logit_bias", llama.params.logit_bias},
- {"n_probs", llama.params.n_probs},
+ {"logit_bias", sparams.logit_bias},
+ {"n_probs", sparams.n_probs},
{"grammar", llama.params.grammar},
};
}
@@ -1094,7 +1098,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"timings", format_timings(llama)},
};
- if (llama.params.n_probs > 0)
+ if (llama.params.sampling_params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1110,7 +1114,7 @@ static json format_partial_response(
{"stop", false},
};
- if (llama.params.n_probs > 0)
+ if (llama.params.sampling_params.n_probs > 0)
{
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
@@ -1142,26 +1146,28 @@ static T json_value(const json &body, const std::string &key, const T &default_v
static void parse_options_completion(const json &body, llama_server_context &llama)
{
gpt_params default_params;
+ const auto & default_sparams = default_params.sampling_params;
+ auto & sparams = llama.params.sampling_params;
llama.stream = json_value(body, "stream", false);
llama.params.n_predict = json_value(body, "n_predict", default_params.n_predict);
- llama.params.top_k = json_value(body, "top_k", default_params.top_k);
- llama.params.top_p = json_value(body, "top_p", default_params.top_p);
- llama.params.tfs_z = json_value(body, "tfs_z", default_params.tfs_z);
- llama.params.typical_p = json_value(body, "typical_p", default_params.typical_p);
- llama.params.repeat_last_n = json_value(body, "repeat_last_n", default_params.repeat_last_n);
- llama.params.temp = json_value(body, "temperature", default_params.temp);
- llama.params.repeat_penalty = json_value(body, "repeat_penalty", default_params.repeat_penalty);
- llama.params.presence_penalty = json_value(body, "presence_penalty", default_params.presence_penalty);
- llama.params.frequency_penalty = json_value(body, "frequency_penalty", default_params.frequency_penalty);
- llama.params.mirostat = json_value(body, "mirostat", default_params.mirostat);
- llama.params.mirostat_tau = json_value(body, "mirostat_tau", default_params.mirostat_tau);
- llama.params.mirostat_eta = json_value(body, "mirostat_eta", default_params.mirostat_eta);
- llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
+ sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
+ sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
+ sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
+ sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
+ sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
+ sparams.temp = json_value(body, "temperature", default_sparams.temp);
+ sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
+ sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
+ sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
+ sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
+ sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+ sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+ sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
- llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
+ sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0)
{
@@ -1172,10 +1178,10 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.prompt = "";
}
- llama.params.logit_bias.clear();
+ sparams.logit_bias.clear();
if (json_value(body, "ignore_eos", false))
{
- llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
+ sparams.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
}
const auto &logit_bias = body.find("logit_bias");
@@ -1191,11 +1197,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla
{
if (el[1].is_number())
{
- llama.params.logit_bias[tok] = el[1].get<float>();
+ sparams.logit_bias[tok] = el[1].get<float>();
}
else if (el[1].is_boolean() && !el[1].get<bool>())
{
- llama.params.logit_bias[tok] = -INFINITY;
+ sparams.logit_bias[tok] = -INFINITY;
}
}
}
@@ -1215,6 +1221,8 @@ static void parse_options_completion(const json &body, llama_server_context &lla
}
}
+ llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
@@ -1423,7 +1431,7 @@ int main(int argc, char **argv)
}
auto probs = llama.generated_token_probs;
- if (llama.params.n_probs > 0 && llama.stopped_word) {
+ if (llama.params.sampling_params.n_probs > 0 && llama.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(llama.ctx, llama.stopping_word, false);
probs = std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.end() - stop_word_toks.size());
}
@@ -1475,7 +1483,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
- if (llama.params.n_probs > 0) {
+ if (llama.params.sampling_params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
@@ -1596,7 +1604,7 @@ int main(int argc, char **argv)
std::vector<completion_token_output> probs_output = {};
- if (llama.params.n_probs > 0) {
+ if (llama.params.sampling_params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());