diff options
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 33 |
1 files changed, 7 insertions, 26 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 99660455..a04f1910 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -279,7 +279,7 @@ struct llama_server_context grammar_parser::print_grammar(stderr, parsed_grammar); { - auto it = params.logit_bias.find(llama_token_eos()); + auto it = params.logit_bias.find(llama_token_eos(ctx)); if (it != params.logit_bias.end() && it->second == -INFINITY) { LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); } @@ -402,7 +402,7 @@ struct llama_server_context if (params.n_predict == 0) { has_next_token = false; - result.tok = llama_token_eos(); + result.tok = llama_token_eos(ctx); return result; } @@ -442,7 +442,7 @@ struct llama_server_context llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; // Apply penalties - float nl_logit = logits[llama_token_nl()]; + float nl_logit = logits[llama_token_nl(ctx)]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -452,7 +452,7 @@ struct llama_server_context last_n_repeat, alpha_frequency, alpha_presence); if (!penalize_nl) { - logits[llama_token_nl()] = nl_logit; + logits[llama_token_nl(ctx)] = nl_logit; } if (grammar != nullptr) { @@ -515,7 +515,7 @@ struct llama_server_context // decrement remaining sampling budget --n_remain; - if (!embd.empty() && embd.back() == llama_token_eos()) + if (!embd.empty() && embd.back() == llama_token_eos(ctx)) { // stopping_word = llama_token_to_str(ctx, embd.back()); has_next_token = false; @@ -652,8 +652,6 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); - fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); - fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -774,23 +772,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_ctx = std::stoi(argv[i]); } - else if (arg == "-gqa" || arg == "--gqa") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.n_gqa = std::stoi(argv[i]); - } - else if (arg == "-eps" || arg == "--rms-norm-eps") { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.rms_norm_eps = std::stof(argv[i]); - } else if (arg == "--rope-freq-base") { if (++i >= argc) @@ -968,7 +949,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, static json format_generation_settings(llama_server_context &llama) { - const auto eos_bias = llama.params.logit_bias.find(llama_token_eos()); + const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx)); const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); @@ -1103,7 +1084,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.params.logit_bias.clear(); if (body.value("ignore_eos", false)) { - llama.params.logit_bias[llama_token_eos()] = -INFINITY; + llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; } const auto &logit_bias = body.find("logit_bias"); |