Diffstat (limited to 'examples')
 examples/server/server.cpp | 41 ++++++++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 15 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 035eb24a..0aada8e2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -25,6 +25,7 @@
#include <thread>
#include <mutex>
#include <chrono>
+#include <condition_variable>
#ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1
@@ -541,7 +542,9 @@ struct llama_server_context
std::vector<task_result> queue_results;
std::vector<task_multi> queue_multitasks;
std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks
+ std::condition_variable condition_tasks;
std::mutex mutex_results;
+ std::condition_variable condition_results;
~llama_server_context()
{
@@ -1169,7 +1172,7 @@ struct llama_server_context
void send_error(task_server& task, std::string error)
{
- std::lock_guard<std::mutex> lock(mutex_results);
+ std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = task.id;
res.multitask_id = task.multitask_id;
@@ -1177,6 +1180,7 @@ struct llama_server_context
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
+ condition_results.notify_all();
}
void add_multi_task(int id, std::vector<int>& sub_ids)
@@ -1186,6 +1190,7 @@ struct llama_server_context
multi.id = id;
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
queue_multitasks.push_back(multi);
+ condition_tasks.notify_one();
}
void update_multi_task(int multitask_id, int subtask_id, task_result& result)
@@ -1197,6 +1202,7 @@ struct llama_server_context
{
multitask.subtasks_remaining.erase(subtask_id);
multitask.results.push_back(result);
+ condition_tasks.notify_one();
}
}
}
@@ -1244,7 +1250,7 @@ struct llama_server_context
void send_partial_response(llama_client_slot &slot, completion_token_output tkn)
{
- std::lock_guard<std::mutex> lock(mutex_results);
+ std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@@ -1280,11 +1286,12 @@ struct llama_server_context
}
queue_results.push_back(res);
+ condition_results.notify_all();
}
void send_final_response(llama_client_slot &slot)
{
- std::lock_guard<std::mutex> lock(mutex_results);
+ std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@@ -1340,11 +1347,12 @@ struct llama_server_context
}
queue_results.push_back(res);
+ condition_results.notify_all();
}
void send_embedding(llama_client_slot &slot)
{
- std::lock_guard<std::mutex> lock(mutex_results);
+ std::unique_lock<std::mutex> lock(mutex_results);
task_result res;
res.id = slot.task_id;
res.multitask_id = slot.multitask_id;
@@ -1372,6 +1380,7 @@ struct llama_server_context
};
}
queue_results.push_back(res);
+ condition_results.notify_all();
}
int request_completion(json data, bool infill, bool embedding, int multitask_id)
@@ -1395,6 +1404,7 @@ struct llama_server_context
// otherwise, it's a single-prompt task, we actually queue it
queue_tasks.push_back(task);
+ condition_tasks.notify_one();
return task.id;
}
@@ -1402,13 +1412,10 @@ struct llama_server_context
{
while (true)
{
- std::this_thread::sleep_for(std::chrono::microseconds(5));
- std::lock_guard<std::mutex> lock(mutex_results);
-
- if (queue_results.empty())
- {
- continue;
- }
+ std::unique_lock<std::mutex> lock(mutex_results);
+ condition_results.wait(lock, [&]{
+ return !queue_results.empty();
+ });
for (int i = 0; i < (int) queue_results.size(); i++)
{
@@ -1504,12 +1511,13 @@ struct llama_server_context
void request_cancel(int task_id)
{
- std::lock_guard<std::mutex> lock(mutex_tasks);
+ std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
task.type = CANCEL_TASK;
task.target_id = task_id;
queue_tasks.push_back(task);
+ condition_tasks.notify_one();
}
int split_multiprompt_task(task_server& multiprompt_task)
@@ -1535,7 +1543,7 @@ struct llama_server_context
void process_tasks()
{
- std::lock_guard<std::mutex> lock(mutex_tasks);
+ std::unique_lock<std::mutex> lock(mutex_tasks);
while (!queue_tasks.empty())
{
task_server task = queue_tasks.front();
@@ -1607,6 +1615,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
queue_results.push_back(aggregate_result);
+ condition_results.notify_all();
queue_iterator = queue_multitasks.erase(queue_iterator);
}
@@ -1637,8 +1646,10 @@ struct llama_server_context
LOG_TEE("all slots are idle and system prompt is empty, clear the KV cache\n");
kv_cache_clear();
}
- // avoid 100% usage of cpu all time
- std::this_thread::sleep_for(std::chrono::milliseconds(5));
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ condition_tasks.wait(lock, [&]{
+ return !queue_tasks.empty();
+ });
}
for (llama_client_slot &slot : slots)
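
What the diff does: instead of polling the task and result queues with short sleeps (std::this_thread::sleep_for), consumers now block on a std::condition_variable and producers notify after pushing under the corresponding mutex. That is also why the std::lock_guard locks become std::unique_lock: condition_variable::wait() needs a lock it can release while blocked. The sketch below shows the same wait/notify pattern in isolation; the task type, queue, and function names are illustrative stand-ins, not the server's actual structures.

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Illustrative stand-in for the server's task type.
    struct task { int id; };

    static std::vector<task> queue_tasks;
    static std::mutex mutex_tasks;
    static std::condition_variable condition_tasks;

    // Producer side: push work under the lock, then notify a waiter,
    // mirroring how the patch notifies after queue_tasks.push_back().
    void submit_task(int id)
    {
        std::unique_lock<std::mutex> lock(mutex_tasks);
        queue_tasks.push_back(task{id});
        condition_tasks.notify_one();
    }

    // Consumer side: instead of sleeping and re-checking, block until the
    // predicate holds. wait() releases the mutex while blocked and only
    // returns with it re-acquired and the predicate true, so spurious
    // wakeups and notify-before-wait races are both handled.
    task next_task()
    {
        std::unique_lock<std::mutex> lock(mutex_tasks);
        condition_tasks.wait(lock, []{ return !queue_tasks.empty(); });
        task t = queue_tasks.front();
        queue_tasks.erase(queue_tasks.begin());
        return t;
    }

    int main()
    {
        std::thread worker([]{
            for (int i = 0; i < 3; i++) {
                task t = next_task();
                std::printf("processed task %d\n", t.id);
            }
        });
        for (int i = 0; i < 3; i++) {
            submit_task(i);
        }
        worker.join();
        return 0;
    }

Compared with the removed 5 ms / 5 µs sleep loops, the waiting thread consumes no CPU while idle and wakes as soon as work is queued, which is the point of the change.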