summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
authorvvhg1 <94630311+vvhg1@users.noreply.github.com>2023-10-02 09:42:02 +0200
committerGitHub <noreply@github.com>2023-10-02 10:42:02 +0300
commitc97f01c362ac102c6994edb80008f8608539553a (patch)
treec73baa81f489587329c3f768a3b940e353233012 /examples/server/server.cpp
parentf5ef5cfb18148131fcf45bdd2331f0db5ab7c3d0 (diff)
infill : add new example + extend server API (#3296)
* vvhg-code-infill (#1) * infill in separate example (#2) * reverted changes to main and added infill example * cleanup * naming improvement * make : add missing blank line * fix missing semicolon * brought infill up to current main code * cleanup --------- Co-authored-by: Cebtenzzre <cebtenzzre@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp206
1 files changed, 206 insertions, 0 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index fe9a4255..6dda5e36 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -342,6 +342,70 @@ struct llama_server_context
return true;
}
+ void loadInfill()
+ {
+ auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS
+ auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx));
+ prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx));
+ prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
+ prefix_tokens.push_back(llama_token_middle(ctx));
+ auto prompt_tokens = prefix_tokens;
+
+ num_prompt_tokens = prompt_tokens.size();
+
+ if (params.n_keep < 0)
+ {
+ params.n_keep = (int)num_prompt_tokens;
+ }
+ params.n_keep = std::min(params.n_ctx - 4, params.n_keep);
+
+ // if input prompt is too big, truncate like normal
+ if (num_prompt_tokens >= (size_t)params.n_ctx)
+ {
+ printf("Input prompt is too big, truncating. Can only take %d tokens but got %zu\n", params.n_ctx, num_prompt_tokens);
+ // todo we probably want to cut from both sides
+ const int n_left = (params.n_ctx - params.n_keep) / 2;
+ std::vector<llama_token> new_tokens(prompt_tokens.begin(), prompt_tokens.begin() + params.n_keep);
+ const int erased_blocks = (num_prompt_tokens - params.n_keep - n_left - 1) / n_left;
+ new_tokens.insert(new_tokens.end(), prompt_tokens.begin() + params.n_keep + erased_blocks * n_left, prompt_tokens.end());
+ std::copy(prompt_tokens.end() - params.n_ctx, prompt_tokens.end(), last_n_tokens.begin());
+
+ LOG_VERBOSE("input truncated", {
+ {"n_ctx", params.n_ctx},
+ {"n_keep", params.n_keep},
+ {"n_left", n_left},
+ {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
+ });
+
+ truncated = true;
+ prompt_tokens = new_tokens;
+ }
+ else
+ {
+ const size_t ps = num_prompt_tokens;
+ std::fill(last_n_tokens.begin(), last_n_tokens.end() - ps, 0);
+ std::copy(prompt_tokens.begin(), prompt_tokens.end(), last_n_tokens.end() - ps);
+ }
+
+ // compare the evaluated prompt with the new prompt
+ n_past = common_part(embd, prompt_tokens);
+ embd = prompt_tokens;
+ if (n_past == num_prompt_tokens)
+ {
+ // we have to evaluate at least 1 token to generate logits.
+ printf("we have to evaluate at least 1 token to generate logits\n");
+ n_past--;
+ }
+
+ LOG_VERBOSE("prompt ingested", {
+ {"n_past", n_past},
+ {"cached", tokens_to_str(ctx, embd.cbegin(), embd.cbegin() + n_past)},
+ {"to_eval", tokens_to_str(ctx, embd.cbegin() + n_past, embd.cend())},
+ });
+
+ has_next_token = true;
+ }
void loadPrompt()
{
auto prompt_tokens = tokenize(prompt, true); // always add BOS
@@ -1219,6 +1283,27 @@ static void parse_options_completion(const json &body, llama_server_context &lla
LOG_VERBOSE("completion parameters parsed", format_generation_settings(llama));
}
+static void parse_options_infill(const json &body, llama_server_context &llama)
+{
+ if (body.count("input_prefix") != 0)
+ {
+ llama.params.input_prefix = body["input_prefix"];
+ }
+ else
+ {
+ llama.params.input_prefix = "";
+ }
+ if (body.count("input_suffix") != 0)
+ {
+ llama.params.input_suffix = body["input_suffix"];
+ }
+ else
+ {
+ llama.params.input_suffix = "";
+ }
+ parse_options_completion(body, llama);
+}
+
static void log_server_request(const Request &req, const Response &res)
{
LOG_INFO("request", {
@@ -1519,6 +1604,127 @@ int main(int argc, char **argv)
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
} });
+ svr.Post("/infill", [&llama](const Request &req, Response &res)
+ {
+ auto lock = llama.lock();
+
+ llama.rewind();
+
+ llama_reset_timings(llama.ctx);
+
+ parse_options_infill(json::parse(req.body), llama);
+
+ if (!llama.loadGrammar())
+ {
+ res.status = 400;
+ return;
+ }
+ llama.loadInfill();
+ llama.beginCompletion();
+ const auto chunked_content_provider = [&](size_t, DataSink & sink) {
+ size_t sent_count = 0;
+ size_t sent_token_probs_index = 0;
+
+ while (llama.has_next_token) {
+ const completion_token_output token_with_probs = llama.doCompletion();
+ if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
+ continue;
+ }
+ const std::string token_text = llama_token_to_piece(llama.ctx, token_with_probs.tok);
+
+ size_t pos = std::min(sent_count, llama.generated_text.size());
+
+ const std::string str_test = llama.generated_text.substr(pos);
+ bool is_stop_full = false;
+ size_t stop_pos =
+ llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
+ if (stop_pos != std::string::npos) {
+ is_stop_full = true;
+ llama.generated_text.erase(
+ llama.generated_text.begin() + pos + stop_pos,
+ llama.generated_text.end());
+ pos = std::min(sent_count, llama.generated_text.size());
+ } else {
+ is_stop_full = false;
+ stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
+ STOP_PARTIAL);
+ }
+
+ if (
+ stop_pos == std::string::npos ||
+ // Send rest of the text if we are at the end of the generation
+ (!llama.has_next_token && !is_stop_full && stop_pos > 0)
+ ) {
+ const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
+
+ sent_count += to_send.size();
+
+ std::vector<completion_token_output> probs_output = {};
+
+ if (llama.params.n_probs > 0) {
+ const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
+ size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
+ size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
+ if (probs_pos < probs_stop_pos) {
+ probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
+ }
+ sent_token_probs_index = probs_stop_pos;
+ }
+
+ const json data = format_partial_response(llama, to_send, probs_output);
+
+ const std::string str =
+ "data: " +
+ data.dump(-1, ' ', false, json::error_handler_t::replace) +
+ "\n\n";
+
+ LOG_VERBOSE("data stream", {
+ { "to_send", str }
+ });
+
+ if (!sink.write(str.data(), str.size())) {
+ LOG_VERBOSE("stream closed", {});
+ llama_print_timings(llama.ctx);
+ return false;
+ }
+ }
+
+ if (!llama.has_next_token) {
+ // Generation is done, send extra information.
+ const json data = format_final_response(
+ llama,
+ "",
+ std::vector<completion_token_output>(llama.generated_token_probs.begin(), llama.generated_token_probs.begin() + sent_token_probs_index)
+ );
+
+ const std::string str =
+ "data: " +
+ data.dump(-1, ' ', false, json::error_handler_t::replace) +
+ "\n\n";
+
+ LOG_VERBOSE("data stream", {
+ { "to_send", str }
+ });
+
+ if (!sink.write(str.data(), str.size())) {
+ LOG_VERBOSE("stream closed", {});
+ llama_print_timings(llama.ctx);
+ return false;
+ }
+ }
+ }
+
+ llama_print_timings(llama.ctx);
+ sink.done();
+ return true;
+ };
+ const auto on_complete = [&](bool) {
+ llama.mutex.unlock();
+ };
+ lock.release();
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+ });
+
svr.Get("/model.json", [&llama](const Request &, Response &res)
{
const json data = format_generation_settings(llama);