summaryrefslogtreecommitdiff
path: root/examples/server
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-06-10 14:59:55 +0300
committerGitHub <noreply@github.com>2024-06-10 14:59:55 +0300
commitd9da0e4986f121c727bdd9579a6688097b11602c (patch)
tree46bad40b3b41a0c817ab1573861e673161341e95 /examples/server
parent1f0dabda8d5c131f9d4632aa41de74317cdd61fb (diff)
server : improve "prompt" handling (#7847)
Diffstat (limited to 'examples/server')
-rw-r--r--examples/server/server.cpp30
1 files changed, 16 insertions, 14 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 6ffaa8d9..80714fa5 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -147,7 +147,7 @@ struct server_slot {
int32_t n_prompt_tokens = 0;
int32_t n_prompt_tokens_processed = 0;
- json prompt;
+ std::string prompt;
// when a task is submitted, we first tokenize the prompt and store it here
std::vector<llama_token> prompt_tokens;
@@ -822,13 +822,8 @@ struct server_context {
continue;
}
- // skip the slot if it does not contains prompt
- if (!slot.prompt.is_string()) {
- continue;
- }
-
// current slot's prompt
- std::string slot_prompt = slot.prompt.get<std::string>();
+ std::string slot_prompt = slot.prompt;
// length of the current slot's prompt
int slot_prompt_len = slot_prompt.size();
@@ -958,13 +953,16 @@ struct server_context {
if (!task.infill) {
const auto & prompt = data.find("prompt");
if (prompt == data.end()) {
- send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
+ send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
return false;
- } else {
- slot.prompt = *prompt;
}
- if (slot.prompt.is_array() && slot.prompt.size() == 0) {
- send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
+
+ if (prompt->is_string()) {
+ slot.prompt = prompt->get<std::string>();
+ } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) {
+ slot.prompt = prompt->at(0).get<std::string>();
+ } else {
+ send_error(task, "\"prompt\" must be a string or an array of strings", ERROR_TYPE_INVALID_REQUEST);
return false;
}
}
@@ -1582,14 +1580,18 @@ struct server_context {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
{
- int id_slot = json_value(task.data, "id_slot", -1);
- std::string prompt = json_value(task.data, "prompt", std::string());
+ const int id_slot = json_value(task.data, "id_slot", -1);
server_slot * slot;
if (id_slot != -1) {
slot = get_slot_by_id(id_slot);
} else {
+ std::string prompt;
+ if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
+ json_value(task.data, "prompt", std::string());
+ }
+
slot = get_available_slot(prompt);
}