Diffstat (limited to 'examples/server/server.cpp'):
 examples/server/server.cpp | 33 +++++++--------------------------
 1 file changed, 7 insertions(+), 26 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 99660455..a04f1910 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -279,7 +279,7 @@ struct llama_server_context
             grammar_parser::print_grammar(stderr, parsed_grammar);
 
             {
-                auto it = params.logit_bias.find(llama_token_eos());
+                auto it = params.logit_bias.find(llama_token_eos(ctx));
                 if (it != params.logit_bias.end() && it->second == -INFINITY) {
                     LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
                 }
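A note on the signature change above: this diff tracks an upstream llama.cpp API change in which the special-token getters (llama_token_eos, llama_token_nl, and friends) gained a llama_context parameter, so the ids can come from the loaded model rather than compile-time constants. A minimal sketch of the updated usage; the helper name is illustrative and not part of the diff:

    #include <cmath>
    #include <unordered_map>
    #include "llama.h"

    // Mirrors the server's check above: EOS counts as banned when it is
    // mapped to a -infinity logit bias.
    static bool eos_is_banned(llama_context * ctx,
                              const std::unordered_map<llama_token, float> & logit_bias) {
        const auto it = logit_bias.find(llama_token_eos(ctx)); // was: llama_token_eos()
        return it != logit_bias.end() && it->second == -INFINITY;
    }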
@@ -402,7 +402,7 @@ struct llama_server_context
         if (params.n_predict == 0)
         {
             has_next_token = false;
-            result.tok = llama_token_eos();
+            result.tok = llama_token_eos(ctx);
             return result;
         }
@@ -442,7 +442,7 @@ struct llama_server_context
             llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
 
             // Apply penalties
-            float nl_logit = logits[llama_token_nl()];
+            float nl_logit = logits[llama_token_nl(ctx)];
             auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx);
             llama_sample_repetition_penalty(ctx, &candidates_p,
                 last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
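For context on what these penalty calls mutate: candidates_p is built from the raw logits just before this hunk. A sketch of that construction, assuming logits points at the output of llama_get_logits(ctx) for the last evaluated token and that llama_n_vocab also takes the context in this revision:

    #include <vector>
    #include "llama.h"

    // One llama_token_data entry per vocabulary id, carrying the raw logit;
    // the sampler reads and modifies these copies, not the logits array.
    const int n_vocab = llama_n_vocab(ctx);
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; id++) {
        candidates.emplace_back(llama_token_data{id, logits[id], 0.0f});
    }
    llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};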
@@ -452,7 +452,7 @@ struct llama_server_context
                 last_n_repeat, alpha_frequency, alpha_presence);
 
             if (!penalize_nl)
             {
-                logits[llama_token_nl()] = nl_logit;
+                logits[llama_token_nl(ctx)] = nl_logit;
             }
 
             if (grammar != nullptr) {
@@ -515,7 +515,7 @@ struct llama_server_context
         // decrement remaining sampling budget
         --n_remain;
 
-        if (!embd.empty() && embd.back() == llama_token_eos())
+        if (!embd.empty() && embd.back() == llama_token_eos(ctx))
         {
             // stopping_word = llama_token_to_str(ctx, embd.back());
             has_next_token = false;
@@ -652,8 +652,6 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "  -v, --verbose             verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
     fprintf(stdout, "  -t N, --threads N         number of threads to use during computation (default: %d)\n", params.n_threads);
     fprintf(stdout, "  -c N, --ctx-size N        size of the prompt context (default: %d)\n", params.n_ctx);
-    fprintf(stdout, "  -gqa N, --gqa N           grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
-    fprintf(stdout, "  -eps N, --rms-norm-eps N  rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
     fprintf(stdout, "  --rope-freq-base N        RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
     fprintf(stdout, "  --rope-freq-scale N       RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  -b N, --batch-size N      batch size for prompt processing (default: %d)\n", params.n_batch);
@@ -774,23 +772,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             }
             params.n_ctx = std::stoi(argv[i]);
         }
-        else if (arg == "-gqa" || arg == "--gqa")
-        {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            params.n_gqa = std::stoi(argv[i]);
-        }
-        else if (arg == "-eps" || arg == "--rms-norm-eps") {
-            if (++i >= argc)
-            {
-                invalid_param = true;
-                break;
-            }
-            params.rms_norm_eps = std::stof(argv[i]);
-        }
         else if (arg == "--rope-freq-base")
         {
             if (++i >= argc)
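The removed -gqa and -eps branches used the same guard as every remaining value-taking flag: advance i, bounds-check against argc, then convert with std::stoi or std::stof. A self-contained sketch of that pattern; the flag names and defaults here are illustrative, not taken from the diff:

    #include <cstdio>
    #include <string>

    int main(int argc, char ** argv) {
        int   n_ctx          = 512;       // illustrative default
        float rope_freq_base = 10000.0f;  // illustrative default
        bool  invalid_param  = false;

        for (int i = 1; i < argc; i++) {
            const std::string arg = argv[i];
            if (arg == "-c" || arg == "--ctx-size") {
                if (++i >= argc) { invalid_param = true; break; } // flag missing its value
                n_ctx = std::stoi(argv[i]);
            } else if (arg == "--rope-freq-base") {
                if (++i >= argc) { invalid_param = true; break; }
                rope_freq_base = std::stof(argv[i]);
            }
        }

        if (invalid_param) {
            fprintf(stderr, "error: flag given without a value\n");
            return 1;
        }
        printf("n_ctx = %d, rope_freq_base = %.1f\n", n_ctx, rope_freq_base);
        return 0;
    }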
@@ -968,7 +949,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 static json format_generation_settings(llama_server_context &llama)
 {
-    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos());
+    const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
     const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                             eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -1103,7 +1084,7 @@ static void parse_options_completion(const json &body, llama_server_context &llama)
     llama.params.logit_bias.clear();
 
     if (body.value("ignore_eos", false))
     {
-        llama.params.logit_bias[llama_token_eos()] = -INFINITY;
+        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
    }
 
     const auto &logit_bias = body.find("logit_bias");
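Taken together, the last two hunks show how the server represents ignore_eos purely as a logit-bias entry: the request path maps the EOS id to -INFINITY, and format_generation_settings recovers the flag by testing for a negative infinite bias on the same id. A condensed sketch, assuming nlohmann::json for body (as the server uses) and the ctx-taking llama_token_eos:

    #include <cmath>

    // Request side: "ignore_eos": true becomes a -infinity bias on the EOS id.
    llama.params.logit_bias.clear();
    if (body.value("ignore_eos", false)) {
        llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY;
    }

    // Reporting side: recover the flag from the same map entry.
    const auto eos_bias   = llama.params.logit_bias.find(llama_token_eos(llama.ctx));
    const bool ignore_eos = eos_bias != llama.params.logit_bias.end() &&
                            eos_bias->second < 0.0f && std::isinf(eos_bias->second);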