summaryrefslogtreecommitdiff
path: root/examples/server/server.cpp
diff options
context:
space:
mode:
authorAdrian Hesketh <a-h@users.noreply.github.com>2023-11-01 09:28:28 +0000
committerGitHub <noreply@github.com>2023-11-01 11:28:28 +0200
commitca190bca8e844d171020d6147687e71472d71734 (patch)
tree449b9b27e0f8b7b02522536e5f04a5831094cadc /examples/server/server.cpp
parent71e3718abdb2771b50c9606d3a7569623a0b0afe (diff)
server : re-enable completion and embedded at the same time (#3876)
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--examples/server/server.cpp16
1 files changed, 10 insertions, 6 deletions
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index c163c7f8..47ae0d55 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -149,6 +149,7 @@ struct task_server {
task_type type;
json data;
bool infill_mode = false;
+ bool embedding_mode = false;
};
struct task_result {
@@ -371,6 +372,7 @@ struct llama_client_slot
std::vector<completion_token_output> generated_token_probs;
bool infill = false;
+ bool embedding = false;
bool has_next_token = true;
bool truncated = false;
bool stopped_eos = false;
@@ -1244,13 +1246,14 @@ struct llama_server_context
queue_results.push_back(res);
}
- int request_completion(json data, bool infill)
+ int request_completion(json data, bool infill, bool embedding)
{
std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.data = data;
task.infill_mode = infill;
+ task.embedding_mode = embedding;
task.type = COMPLETION_TASK;
queue_tasks.push_back(task);
return task.id;
@@ -1376,7 +1379,7 @@ struct llama_server_context
{
LOG_TEE("slot unavailable\n");
// send error result
- send_error(task.id, "slot unavaliable");
+ send_error(task.id, "slot unavailable");
return;
}
@@ -1388,6 +1391,7 @@ struct llama_server_context
slot->reset();
slot->infill = task.infill_mode;
+ slot->embedding = task.embedding_mode;
slot->task_id = task.id;
if (!launch_slot_with_data(slot, task.data))
@@ -1695,7 +1699,7 @@ struct llama_server_context
}
// prompt evaluated for embedding
- if (params.embedding)
+ if (slot.embedding)
{
send_embedding(slot);
slot.release();
@@ -2274,7 +2278,7 @@ int main(int argc, char **argv)
svr.Post("/completion", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, false);
+ const int task_id = llama.request_completion(data, false, false);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
@@ -2329,7 +2333,7 @@ int main(int argc, char **argv)
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = json::parse(req.body);
- const int task_id = llama.request_completion(data, true);
+ const int task_id = llama.request_completion(data, true, false);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
@@ -2433,7 +2437,7 @@ int main(int argc, char **argv)
{
prompt = "";
}
- const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false);
+ const int task_id = llama.request_completion({ {"prompt", prompt}, { "n_predict", 0} }, false, true);
task_result result = llama.next_result(task_id);
return res.set_content(result.result_json.dump(), "application/json");
});