author     Xiao-Yong Jin <jinxiaoyong@gmail.com>     2023-08-23 02:12:12 -0500
committer  GitHub <noreply@github.com>               2023-08-23 15:12:12 +0800
commit     b8ad1b66b23f9b2e6e4531e9a62753323036a556 (patch)
tree       72799c23c8335ee997ab579c41313449ca2e4e91 /examples/server/server.cpp
parent     f5fe98d11bdf9e7797bcfb05c0c3601ffc4b9d26 (diff)
server : allow json array in prompt or content for direct token input (#2306)
* server: allow json array in prompt or content
In addition to the current string-valued prompt or content, we now accept
an array of strings and numbers, where the numbers are token ids.
This allows direct token input: special tokens can be resolved at the
frontend while the JSON data is constructed, before it is sent to the
server, and the server no longer needs to know about or parse special
tokens in textual input. With this, the BOS and EOS tokens required by
llama-2-chat models can be supplied directly (see the sketch after this list).
* server: use tokenizePrompt(json) and default "" if empty prompt
* server: fix prompt check
* server: tokenize endpoint no longer adds BOS
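
A minimal sketch of the new request shape, using nlohmann::json (the same JSON
library the server uses). The chat-template text and the BOS/EOS token ids
(1 and 2) are illustrative placeholders, not values defined by this change;
a real client should use the ids reported by its model's tokenizer.

    // Sketch: build a /completion request whose "prompt" mixes raw token ids
    // with plain strings, so special tokens bypass textual parsing entirely.
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::json;

        const int bos_id = 1; // placeholder BOS id for a llama-2-chat model
        const int eos_id = 2; // placeholder EOS id

        json body;
        body["prompt"] = json::array({
            bos_id, "[INST] Hi there [/INST]",      // strings are tokenized server-side
            " Hello! How can I help?", eos_id,      // numbers are passed through as token ids
            bos_id, "[INST] What is 1 + 1? [/INST]"
        });
        body["n_predict"] = 64;

        // POST body.dump() to the server's /completion endpoint.
        std::cout << body.dump(2) << std::endl;
        return 0;
    }

Because the first array element is a number rather than a string, the
tokenize() helper added in the diff below will not prepend another BOS token,
and the embedded BOS/EOS ids pass through untouched.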
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  80
1 file changed, 73 insertions, 7 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e5bc52cd..1e6d10c1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -190,6 +190,7 @@ struct llama_server_context
     size_t n_past = 0;
     size_t n_remain = 0;
 
+    json prompt;
     std::vector<llama_token> embd;
     std::vector<llama_token> last_n_tokens;
@@ -267,6 +268,53 @@ struct llama_server_context
         return true;
     }
 
+    std::vector<llama_token> tokenize(json json_prompt, bool add_bos)
+    {
+        // If `add_bos` is true, we only add BOS, when json_prompt is a string,
+        // or the first element of the json_prompt array is a string.
+        std::vector<llama_token> prompt_tokens;
+
+        if (json_prompt.is_array())
+        {
+            bool first = true;
+            for (const auto& p : json_prompt)
+            {
+                if (p.is_string())
+                {
+                    auto s = p.template get<std::string>();
+                    std::vector<llama_token> p;
+                    if (first)
+                    {
+                        s.insert(0, 1, ' '); // add a space if it's the first
+                        p = ::llama_tokenize(ctx, s, add_bos);
+                        first = false;
+                    }
+                    else
+                    {
+                        p = ::llama_tokenize(ctx, s, false);
+                    }
+                    prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
+                }
+                else
+                {
+                    if (first)
+                    {
+                        first = false;
+                    }
+                    prompt_tokens.push_back(p.template get<llama_token>());
+                }
+            }
+        }
+        else
+        {
+            auto s = json_prompt.template get<std::string>();
+            s.insert(0, 1, ' '); // always add a first space
+            prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
+        }
+
+        return prompt_tokens;
+    }
+
     bool loadGrammar()
     {
         if (!params.grammar.empty()) {
@@ -294,8 +342,8 @@ struct llama_server_context
     void loadPrompt()
     {
-        params.prompt.insert(0, 1, ' '); // always add a first space
-        std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+        auto prompt_tokens = tokenize(prompt, true); // always add BOS
+
         num_prompt_tokens = prompt_tokens.size();
 
         if (params.n_keep < 0)
@@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string
         {"tokens_predicted", llama.num_tokens_predicted},
         {"tokens_evaluated", llama.num_prompt_tokens},
         {"generation_settings", format_generation_settings(llama)},
-        {"prompt", llama.params.prompt},
+        {"prompt", llama.prompt},
         {"truncated", llama.truncated},
         {"stopped_eos", llama.stopped_eos},
         {"stopped_word", llama.stopped_word},
@@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
     llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     llama.params.seed = json_value(body, "seed", default_params.seed);
-    llama.params.prompt = json_value(body, "prompt", default_params.prompt);
     llama.params.grammar = json_value(body, "grammar", default_params.grammar);
     llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
 
+    if (body.count("prompt") != 0)
+    {
+        llama.prompt = body["prompt"];
+    }
+    else
+    {
+        llama.prompt = "";
+    }
+
     llama.params.logit_bias.clear();
     if (json_value(body, "ignore_eos", false))
     {
@@ -1345,8 +1401,11 @@ int main(int argc, char **argv)
             auto lock = llama.lock();
 
             const json body = json::parse(req.body);
-            const std::string content = json_value<std::string>(body, "content", "");
-            const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
+            std::vector<llama_token> tokens;
+            if (body.count("content") != 0)
+            {
+                tokens = llama.tokenize(body["content"], false);
+            }
             const json data = format_tokenizer_response(tokens);
             return res.set_content(data.dump(), "application/json"); });
@@ -1358,7 +1417,14 @@ int main(int argc, char **argv)
             llama.rewind();
 
             llama_reset_timings(llama.ctx);
-            llama.params.prompt = json_value<std::string>(body, "content", "");
+            if (body.count("content") != 0)
+            {
+                llama.prompt = body["content"];
+            }
+            else
+            {
+                llama.prompt = "";
+            }
             llama.params.n_predict = 0;
             llama.loadPrompt();
             llama.beginCompletion();
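
To go with the tokenize-endpoint change (no BOS is added any more, and
"content" accepts the same string-or-array form as "prompt"), here is a
hedged client-side sketch. It assumes cpp-httplib, the HTTP library the
server is built on, and the server's default listen address localhost:8080.

    // Sketch: ask the server to tokenize a plain string; after this change,
    // the returned token list no longer starts with a BOS id.
    #include <httplib.h>
    #include <nlohmann/json.hpp>
    #include <iostream>

    int main() {
        using json = nlohmann::json;
        httplib::Client cli("localhost", 8080);

        json body;
        body["content"] = "Hello, world!";

        auto res = cli.Post("/tokenize", body.dump(), "application/json");
        if (res && res->status == 200) {
            std::cout << res->body << std::endl; // e.g. {"tokens":[...]}
        }
        return 0;
    }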