diff options
author | vvhg1 <94630311+vvhg1@users.noreply.github.com> | 2023-10-02 09:42:02 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-02 10:42:02 +0300 |
commit | c97f01c362ac102c6994edb80008f8608539553a (patch) | |
tree | c73baa81f489587329c3f768a3b940e353233012 /examples/server/server.cpp | |
parent | f5ef5cfb18148131fcf45bdd2331f0db5ab7c3d0 (diff) |
infill : add new example + extend server API (#3296)
* vvhg-code-infill (#1)
* infill in separate example (#2)
* reverted changes to main and added infill example
* cleanup
* naming improvement
* make : add missing blank line
* fix missing semicolon
* brought infill up to current main code
* cleanup
---------
Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r-- | examples/server/server.cpp | 206 |
1 files changed, 206 insertions, 0 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index fe9a4255..6dda5e36 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -342,6 +342,70 @@ struct llama_server_context return true; } + void loadInfill() + { + auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS + auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS + prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); + prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); + prefix_tokens.push_back(llama_token_middle(ctx)); + auto prompt_tokens = prefix_tokens; + + num_prompt_tokens = prompt_tokens.size(); + + if (params.n_keep < 0) + { + params.n_keep = (int)num_prompt_tokens; + } + params.n_keep = std::min(params.n_ctx - 4, params.n_keep); + + // if input prompt is too big, truncate like normal + if (num_prompt_tokens >= (size_t)params.n_ctx) + { + printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens); + // todo we probably want to cut from both sides + const int n_left = (params.n_ctx - params.n_keep) / 2; + std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep); + const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left; + new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end()); + std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin()); + + LOG_VERBOSE("input truncated", { + {"n_ctx", params.n_ctx}, + {"n_keep", params.n_keep}, + {"n_left", n_left}, + {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())}, + }); + + truncated = true; + prompt_tokens = new_tokens; + } + else + { + const size_t ps = num_prompt_tokens; + std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0); + std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps); + } + + // compare the evaluated prompt with the new prompt + n_past = common_part(embd, prompt_tokens); + embd = prompt_tokens; + if (n_past == num_prompt_tokens) + { + // we have to evaluate at least 1 token to generate logits. + printf("we have to evaluate at least 1 token to generate logits\n"); + n_past--; + } + + LOG_VERBOSE("prompt ingested", { + {"n_past", n_past}, + {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)}, + {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())}, + }); + + has_next_token = true; + } void loadPrompt() { auto prompt_tokens = tokenize(prompt, true); // always add BOS @@ -1219,6 +1283,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama)); } +static void parse_options_infill(const json &body, llama_server_context &llama) +{ + if (body.count("input_prefix") != 0) + { + llama.params.input_prefix = body["input_prefix"]; + } + else + { + llama.params.input_prefix = ""; + } + if (body.count("input_suffix") != 0) + { + llama.params.input_suffix = body["input_suffix"]; + } + else + { + llama.params.input_suffix = ""; + } + parse_options_completion(body, llama); +} + static void log_server_request(const Request &req, const Response &res) { LOG_INFO("request", { @@ -1519,6 +1604,127 @@ int main(int argc, char **argv) res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); } }); + svr.Post("/infill", [&llama](const Request &req, Response &res) + { + auto lock = llama.lock(); + + llama.rewind(); + + llama_reset_timings(llama.ctx); + + parse_options_infill(json::parse(req.body), llama); + + if (!llama.loadGrammar()) + { + res.status = 400; + return; + } + llama.loadInfill(); + llama.beginCompletion(); + const auto chunked_content_provider = [&](size_t, DataSink & sink) { + size_t sent_count = 0; + size_t sent_token_probs_index = 0; + + while (llama.has_next_token) { + const completion_token_output token_with_probs = llama.doCompletion(); + if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) { + continue; + } + const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok); + + size_t pos = std::min(sent_count, llama.generated_text.size()); + + const std::string str_test = llama.generated_text.substr(pos); + bool is_stop_full = false; + size_t stop_pos = + llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); + if (stop_pos != std::string::npos) { + is_stop_full = true; + llama.generated_text.erase( + llama.generated_text.begin() + pos + stop_pos, + llama.generated_text.end()); + pos = std::min(sent_count, llama.generated_text.size()); + } else { + is_stop_full = false; + stop_pos = llama.findStoppingStrings(str_test, token_text.size(), + STOP_PARTIAL); + } + + if ( + stop_pos == std::string::npos || + // Send rest of the text if we are at the end of the generation + (!llama.has_next_token && !is_stop_full && stop_pos > 0) + ) { + const std::string to_send = llama.generated_text.substr(pos, std::string::npos); + + sent_count += to_send.size(); + + std::vector<completion_token_output> probs_output = {}; + + if (llama.params.n_probs > 0) { + const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false); + size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size()); + size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size()); + if (probs_pos < probs_stop_pos) { + probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos); + } + sent_token_probs_index = probs_stop_pos; + } + + const json data = format_partial_response(llama, to_send, probs_output); + + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.data(), str.size())) { + LOG_VERBOSE("stream closed", {}); + llama_print_timings(llama.ctx); + return false; + } + } + + if (!llama.has_next_token) { + // Generation is done, send extra information. + const json data = format_final_response( + llama, + "", + std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index) + ); + + const std::string str = + "data: " + + data.dump(-1, ' ', false, json::error_handler_t::replace) + + "\n\n"; + + LOG_VERBOSE("data stream", { + { "to_send", str } + }); + + if (!sink.write(str.data(), str.size())) { + LOG_VERBOSE("stream closed", {}); + llama_print_timings(llama.ctx); + return false; + } + } + } + + llama_print_timings(llama.ctx); + sink.done(); + return true; + }; + const auto on_complete = [&](bool) { + llama.mutex.unlock(); + }; + lock.release(); + res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); + }); + svr.Get("/model.json", [&llama](const Request &, Response &res) { const json data = format_generation_settings(llama); |