Diffstat (limited to 'examples/server')
-rw-r--r-- | examples/server/server.cpp | 73
1 file changed, 20 insertions, 53 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ee0ababb..28b3f3f5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1,7 +1,6 @@
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"
-#include "grammar-parser.h"
 
 #ifndef NDEBUG
 // crash the server in debug mode, otherwise send an http 500 error
@@ -195,17 +194,13 @@ struct llama_server_context
     json prompt;
     std::vector<llama_token> embd;
-    std::vector<llama_token> last_n_tokens;
 
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;
     gpt_params params;
-    llama_sampling_context ctx_sampling;
+    llama_sampling_context *ctx_sampling;
 
     int n_ctx;
 
-    grammar_parser::parse_state parsed_grammar;
-    llama_grammar *grammar = nullptr;
-
     bool truncated = false;
     bool stopped_eos = false;
     bool stopped_word = false;
@@ -252,11 +247,10 @@ struct llama_server_context
         n_remain = 0;
         n_past = 0;
 
-        if (grammar != nullptr) {
-            llama_grammar_free(grammar);
-            grammar = nullptr;
-            ctx_sampling = llama_sampling_context_init(params, NULL);
+        if (ctx_sampling != nullptr) {
+            llama_sampling_free(ctx_sampling);
         }
+        ctx_sampling = llama_sampling_init(params);
     }
 
     bool loadModel(const gpt_params &params_)
@@ -269,8 +263,6 @@ struct llama_server_context
             return false;
         }
         n_ctx = llama_n_ctx(ctx);
-        last_n_tokens.resize(n_ctx);
-        std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
         return true;
     }
 
@@ -321,27 +313,7 @@ struct llama_server_context
 
     bool loadGrammar()
     {
-        if (!params.grammar.empty()) {
-            parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-            // will be empty (default) if there are parse errors
-            if (parsed_grammar.rules.empty()) {
-                LOG_ERROR("grammar parse error", {{"grammar", params.grammar}});
-                return false;
-            }
-            grammar_parser::print_grammar(stderr, parsed_grammar);
-
-            {
-                auto it = params.sampling_params.logit_bias.find(llama_token_eos(ctx));
-                if (it != params.sampling_params.logit_bias.end() && it->second == -INFINITY) {
-                    LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {});
-                }
-            }
-
-            std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-            grammar = llama_grammar_init(
-                grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-        }
-        ctx_sampling = llama_sampling_context_init(params, grammar);
+        ctx_sampling = llama_sampling_init(params);
         return true;
     }
 
@@ -383,7 +355,7 @@ struct llama_server_context
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+            std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
 
             LOG_VERBOSE("input truncated", {
                                                {"n_ctx", params.n_ctx},
@@ -398,8 +370,8 @@ struct llama_server_context
         else
         {
             const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -443,7 +415,7 @@ struct llama_server_context
             std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
             const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
             new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
-            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+            std::copy(prompt_tokens.end() - n_ctx, prompt_tokens.end(), ctx_sampling->prev.begin());
 
             LOG_VERBOSE("input truncated", {
                                                {"n_ctx", n_ctx},
@@ -458,8 +430,8 @@ struct llama_server_context
         else
        {
             const size_t ps = num_prompt_tokens;
-            std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
-            std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+            std::fill(ctx_sampling->prev.begin(), ctx_sampling->prev.end() - ps, 0);
+            std::copy(prompt_tokens.begin(), prompt_tokens.end(), ctx_sampling->prev.end() - ps);
         }
 
         // compare the evaluated prompt with the new prompt
@@ -554,27 +526,24 @@ struct llama_server_context
 
         {
             // out of user input, sample next token
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(llama_n_vocab(model));
-
-            result.tok = llama_sampling_sample(ctx, NULL, ctx_sampling, last_n_tokens, candidates);
+            result.tok = llama_sampling_sample(ctx_sampling, ctx, NULL);
 
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+            llama_token_data_array cur_p = { ctx_sampling->cur.data(), ctx_sampling->cur.size(), false };
 
             const int32_t n_probs = params.sampling_params.n_probs;
             if (params.sampling_params.temp <= 0 && n_probs > 0)
             {
                 // For llama_sample_token_greedy we need to sort candidates
-                llama_sample_softmax(ctx, &candidates_p);
+                llama_sample_softmax(ctx, &cur_p);
             }
 
-            for (size_t i = 0; i < std::min(candidates_p.size, (size_t)n_probs); ++i)
+            for (size_t i = 0; i < std::min(cur_p.size, (size_t)n_probs); ++i)
             {
-                result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
+                result.probs.push_back({cur_p.data[i].id, cur_p.data[i].p});
             }
 
-            last_n_tokens.erase(last_n_tokens.begin());
-            last_n_tokens.push_back(result.tok);
+            llama_sampling_accept(ctx_sampling, ctx, result.tok);
+
             if (tg) {
                 num_tokens_predicted++;
             }
@@ -1235,7 +1204,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla
         }
     }
 
-    llama.ctx_sampling = llama_sampling_context_init(llama.params, llama.grammar);
+    llama.ctx_sampling = llama_sampling_init(llama.params);
 
     LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
 }
@@ -1793,9 +1762,7 @@ int main(int argc, char **argv)
         return 1;
     }
 
-    if (llama.grammar != nullptr) {
-        llama_grammar_free(llama.grammar);
-    }
+    llama_sampling_free(llama.ctx_sampling);
     llama_backend_free();
 
     return 0;
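
For orientation, the diff replaces the server's hand-rolled grammar handling and last_n_tokens bookkeeping with the common llama_sampling_* context. The sketch below is not part of the commit; it only illustrates the per-request lifecycle the new code follows, using the calls visible in this diff. The function name run_request and the generation loop are hypothetical, and it assumes ctx (llama_context *) and params (gpt_params) are already initialized and tokens are evaluated elsewhere.

#include "common.h"
#include "llama.h"

static void run_request(llama_context * ctx, const gpt_params & params) {
    // loadGrammar()/rewind(): one init call now covers grammar parsing,
    // repetition history and the candidate buffer
    llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    llama_token tok;
    do {
        // nextToken(): sample from the last logits; candidates live in ctx_sampling->cur
        tok = llama_sampling_sample(ctx_sampling, ctx, NULL);

        // replaces the manual last_n_tokens erase/push_back: records the token
        // in ctx_sampling->prev and advances the sampler state
        llama_sampling_accept(ctx_sampling, ctx, tok);

        // ... evaluate `tok` and stream it to the client (omitted here) ...
    } while (tok != llama_token_eos(ctx));

    // main() teardown: a single free instead of llama_grammar_free()
    llama_sampling_free(ctx_sampling);
}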