Diffstat (limited to 'examples')
-rw-r--r--  examples/server/README.md  |  2
-rw-r--r--  examples/server/server.cpp | 80
2 files changed, 74 insertions, 8 deletions
diff --git a/examples/server/README.md b/examples/server/README.md
index 4d97db2e..77997f98 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -126,7 +126,7 @@ node .
`stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
- `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. A space is inserted in the front like main.cpp does.
+ `prompt`: Provide a prompt as a string, or as an array of strings and numbers representing tokens. Internally, the prompt is compared to the previous one; any part that has already been evaluated is reused, and only the remaining part is evaluated. If the prompt is a string, or an array whose first element is a string, a space is inserted at the front, as main.cpp does.
`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
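For example (an illustrative request body; the numeric ids are made up and depend on the model's vocabulary), a completion request may now mix text and pre-tokenized pieces:

```json
{
    "prompt": ["I believe the meaning of life is", 29871, 29896],
    "n_predict": 64
}
```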
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index e5bc52cd..1e6d10c1 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -190,6 +190,7 @@ struct llama_server_context
size_t n_past = 0;
size_t n_remain = 0;
+ json prompt;
std::vector<llama_token> embd;
std::vector<llama_token> last_n_tokens;
@@ -267,6 +268,53 @@ struct llama_server_context
return true;
}
+ std::vector<llama_token> tokenize(const json &json_prompt, bool add_bos)
+ {
+     // If `add_bos` is true, BOS is added only when json_prompt is a string
+     // or when the first element of the json_prompt array is a string.
+     std::vector<llama_token> prompt_tokens;
+
+     if (json_prompt.is_array())
+     {
+         bool first = true;
+         for (const auto &p : json_prompt)
+         {
+             if (p.is_string())
+             {
+                 auto s = p.get<std::string>();
+                 std::vector<llama_token> p_tokens;
+                 if (first)
+                 {
+                     s.insert(0, 1, ' '); // add a space to the first string, like main.cpp does
+                     p_tokens = ::llama_tokenize(ctx, s, add_bos);
+                     first = false;
+                 }
+                 else
+                 {
+                     p_tokens = ::llama_tokenize(ctx, s, false);
+                 }
+                 prompt_tokens.insert(prompt_tokens.end(), p_tokens.begin(), p_tokens.end());
+             }
+             else
+             {
+                 if (first)
+                 {
+                     first = false; // a leading raw token: later strings get no space or BOS
+                 }
+                 prompt_tokens.push_back(p.get<llama_token>());
+             }
+         }
+     }
+     else
+     {
+         auto s = json_prompt.get<std::string>();
+         s.insert(0, 1, ' '); // always add a first space
+         prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
+     }
+
+     return prompt_tokens;
+ }
+
bool loadGrammar()
{
if (!params.grammar.empty()) {
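As a standalone sketch of the parsing rule this function relies on (assuming only nlohmann::json, which server.cpp already uses; the ids are invented), strings stay strings and numbers become token ids:

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main()
{
    // Illustrative mixed prompt; 42 is a made-up token id.
    json prompt = json::parse(R"(["Hello", 42, " world"])");
    for (const auto &p : prompt)
    {
        if (p.is_string())
            std::cout << "text:  " << p.get<std::string>() << '\n';
        else
            std::cout << "token: " << p.get<int>() << '\n';
    }
}
```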
@@ -294,8 +342,8 @@ struct llama_server_context
void loadPrompt()
{
- params.prompt.insert(0, 1, ' '); // always add a first space
- std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
+ auto prompt_tokens = tokenize(prompt, true); // add BOS when the prompt is a string or starts with one
+
num_prompt_tokens = prompt_tokens.size();
if (params.n_keep < 0)
@@ -1016,7 +1064,7 @@ static json format_final_response(llama_server_context &llama, const std::string
{"tokens_predicted", llama.num_tokens_predicted},
{"tokens_evaluated", llama.num_prompt_tokens},
{"generation_settings", format_generation_settings(llama)},
- {"prompt", llama.params.prompt},
+ {"prompt", llama.prompt},
{"truncated", llama.truncated},
{"stopped_eos", llama.stopped_eos},
{"stopped_word", llama.stopped_word},
@@ -1085,10 +1133,18 @@ static void parse_options_completion(const json &body, llama_server_context &lla
llama.params.penalize_nl = json_value(body, "penalize_nl", default_params.penalize_nl);
llama.params.n_keep = json_value(body, "n_keep", default_params.n_keep);
llama.params.seed = json_value(body, "seed", default_params.seed);
- llama.params.prompt = json_value(body, "prompt", default_params.prompt);
llama.params.grammar = json_value(body, "grammar", default_params.grammar);
llama.params.n_probs = json_value(body, "n_probs", default_params.n_probs);
+ if (body.count("prompt") != 0)
+ {
+ llama.prompt = body["prompt"];
+ }
+ else
+ {
+ llama.prompt = "";
+ }
+
llama.params.logit_bias.clear();
if (json_value(body, "ignore_eos", false))
{
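Note that the `count()` check is equivalent to nlohmann::json's `contains()`; a minimal sketch of the same fallback:

```cpp
// Equivalent to the count() != 0 check above; falls back to an empty string prompt.
llama.prompt = body.contains("prompt") ? body["prompt"] : json("");
```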
@@ -1345,8 +1401,11 @@ int main(int argc, char **argv)
auto lock = llama.lock();
const json body = json::parse(req.body);
- const std::string content = json_value<std::string>(body, "content", "");
- const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
+ std::vector<llama_token> tokens;
+ if (body.count("content") != 0)
+ {
+ tokens = llama.tokenize(body["content"], false);
+ }
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json"); });
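With this change, the `/tokenize` endpoint accepts the same flexible `content` field; an illustrative request body (42 is a made-up id that is passed through verbatim):

```json
{"content": ["Hello", 42, " world"]}
```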
@@ -1358,7 +1417,14 @@ int main(int argc, char **argv)
llama.rewind();
llama_reset_timings(llama.ctx);
- llama.params.prompt = json_value<std::string>(body, "content", "");
+ if (body.count("content") != 0)
+ {
+ llama.prompt = body["content"];
+ }
+ else
+ {
+ llama.prompt = "";
+ }
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();
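The embedding handler gets the same treatment, so `content` may be a plain string or a mixed array; an illustrative request body:

```json
{"content": ["Compute an embedding for this sentence.", 42]}
```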