path: root/examples/server/server.cpp
author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>  2024-08-12 15:14:32 +0200
committer GitHub <noreply@github.com>                             2024-08-12 15:14:32 +0200
commit    8f43e551038af2547b5c01d0e9edd641c0e4bd29 (patch)
tree      07a4373620a9381d0b5c7189a475990a6feb48a5 /examples/server/server.cpp
parent    f5d1af61d79fb53ccfbac2e665e43208c07b083d (diff)
Merge mainline - Aug 12 2024 (#17)
* Merge mainline
* Fix after merge
* Remove CI check

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'examples/server/server.cpp')
-rw-r--r--  examples/server/server.cpp  75
1 file changed, 72 insertions(+), 3 deletions(-)
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 7813a295..360f571e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -78,6 +78,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
 };
 
 struct server_task {
@@ -622,6 +623,7 @@ struct server_response {
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    std::vector<llama_lora_adapter_container> lora_adapters;
 
     gpt_params params;
@@ -677,7 +679,11 @@ struct server_context {
         // dedicate one sequence to the system prompt
         params.n_parallel += 1;
-        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        llama_init_result llama_init = llama_init_from_gpt_params(params);
+
+        model = llama_init.model;
+        ctx = llama_init.context;
+        lora_adapters = llama_init.lora_adapters;
         params.n_parallel -= 1; // but be sneaky about it
         if (model == nullptr) {
             LOG_ERROR("unable to load model", {{"model", params.model}});
@@ -900,7 +906,7 @@ struct server_context {
         slot.params.stream       = json_value(data, "stream",       false);
         slot.params.cache_prompt = json_value(data, "cache_prompt", false);
-        slot.params.n_predict    = json_value(data, "n_predict",    default_params.n_predict);
+        slot.params.n_predict    = json_value(data, "n_predict",    json_value(data, "max_tokens", default_params.n_predict));
         slot.sparams.top_k       = json_value(data, "top_k",        default_sparams.top_k);
         slot.sparams.top_p       = json_value(data, "top_p",        default_sparams.top_p);
         slot.sparams.min_p       = json_value(data, "min_p",        default_sparams.min_p);
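The nested json_value fallback gives the native completion endpoint an OpenAI-compatible "max_tokens" alias: an explicit "n_predict" still wins, and "max_tokens" is consulted only when "n_predict" is absent. A minimal client sketch, assuming a server on localhost:8080; the endpoint path and field names come from this diff, everything else is illustrative:

// Sketch: exercise the new "max_tokens" alias on /completion.
#include "httplib.h"   // cpp-httplib, the HTTP library the server itself uses
#include "json.hpp"    // nlohmann::json, likewise already a server dependency
#include <iostream>

int main() {
    httplib::Client cli("localhost", 8080);
    nlohmann::json body = {
        {"prompt", "Hello"},
        {"max_tokens", 32},  // now honored; "n_predict" takes precedence if both are set
    };
    auto res = cli.Post("/completion", body.dump(), "application/json");
    if (res && res->status == 200) {
        std::cout << res->body << std::endl;
    }
    return 0;
}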
@@ -969,6 +975,8 @@ struct server_context {
                  (prompt->is_array() &&  prompt->size() == 1 && prompt->at(0).is_string()) ||
                  (prompt->is_array() && !prompt->empty()     && prompt->at(0).is_number_integer())) {
                 slot.prompt = *prompt;
+            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
+                slot.prompt = prompt->at(0);
             } else {
                 send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
                 return false;
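The new branch unwraps a singleton array-of-arrays, so a pre-tokenized prompt wrapped one level deep is accepted rather than rejected. A standalone sketch of the shapes the validator distinguishes, using nlohmann::json as the server does (sample values are illustrative):

#include "json.hpp"
#include <cassert>

int main() {
    using json = nlohmann::json;
    json a = "a plain string prompt";    // accepted: string
    json b = json::parse("[1, 2, 3]");   // accepted: array of token ids
    json c = json::parse("[[1, 2, 3]]"); // newly accepted: singleton array wrapping token ids

    // the condition added in this diff
    assert(c.is_array() && c.size() == 1 && c.at(0).is_array());
    json unwrapped = c.at(0);            // what the new branch assigns to slot.prompt
    assert(unwrapped == b);
    return 0;
}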
@@ -1847,6 +1855,14 @@ struct server_context {
                     };
                     queue_results.send(result);
                 } break;
+            case SERVER_TASK_TYPE_SET_LORA:
+                {
+                    llama_lora_adapters_apply(ctx, lora_adapters);
+                    server_task_result result;
+                    result.id = task.id;
+                    result.data = json{{ "success", true }};
+                    queue_results.send(result);
+                } break;
         }
     }
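llama_lora_adapters_apply comes from the common library brought in by this merge; roughly, it clears whatever is attached to the context and re-attaches every adapter with a non-zero scale. A sketch of that expected behavior, assuming the llama.h LoRA API of this period (llama_lora_adapter_clear / llama_lora_adapter_set), not the verbatim implementation:

// Rough sketch of what the call above is expected to do.
void llama_lora_adapters_apply(llama_context * ctx,
                               std::vector<llama_lora_adapter_container> & lora_adapters) {
    llama_lora_adapter_clear(ctx);      // detach all previously applied adapters
    for (auto & la : lora_adapters) {
        if (la.scale != 0.0f) {         // scale 0 means "disabled"
            llama_lora_adapter_set(ctx, la.adapter, la.scale);
        }
    }
}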
@@ -3325,6 +3341,55 @@ int main(int argc, char ** argv) {
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
+    const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        json result = json::array();
+        for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
+            auto & la = ctx_server.lora_adapters[i];
+            result.push_back({
+                {"id", i},
+                {"path", la.path},
+                {"scale", la.scale},
+            });
+        }
+        res.set_content(result.dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+
+        const std::vector<json> body = json::parse(req.body);
+        int max_idx = ctx_server.lora_adapters.size();
+
+        // clear existing value
+        for (auto & la : ctx_server.lora_adapters) {
+            la.scale = 0.0f;
+        }
+
+        // set value
+        for (auto entry : body) {
+            int id      = entry.at("id");
+            float scale = entry.at("scale");
+            if (0 <= id && id < max_idx) {
+                ctx_server.lora_adapters[id].scale = scale;
+            } else {
+                throw std::runtime_error("invalid adapter id");
+            }
+        }
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SET_LORA;
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        res.set_content(result.data.dump(), "application/json");
+        res.status = 200; // HTTP OK
+    };
+
     auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
         return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
             res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
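Taken together, the two handlers form the hot-swap API: GET /lora-adapters lists each adapter's id, path, and current scale, and POST /lora-adapters resets every scale to 0, applies the scales from the request body, then posts a SET_LORA task and waits for the inference thread to acknowledge it. A client sketch, assuming a server on localhost:8080 started with at least one --lora adapter; the paths and the {"id", "scale"} body schema come from this diff, the host and port are illustrative:

#include "httplib.h"
#include <iostream>

int main() {
    httplib::Client cli("localhost", 8080);

    // GET /lora-adapters -> e.g. [{"id":0,"path":"adapter.gguf","scale":0.0}]
    if (auto res = cli.Get("/lora-adapters")) {
        std::cout << res->body << std::endl;
    }

    // POST /lora-adapters: adapters omitted from the body are reset to scale 0.
    auto res = cli.Post("/lora-adapters", R"([{"id": 0, "scale": 1.0}])", "application/json");
    if (res && res->status == 200) {
        std::cout << res->body << std::endl; // {"success":true}
    }
    return 0;
}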
@@ -3363,7 +3428,6 @@ int main(int argc, char ** argv) {
 
     // register API routes
     svr->Get ("/health",        handle_health);
-    svr->Get ("/slots",         handle_slots);
     svr->Get ("/metrics",       handle_metrics);
     svr->Get ("/props",         handle_props);
     svr->Get ("/v1/models",     handle_models);
@@ -3378,6 +3442,11 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
+ // LoRA adapters hotswap
+ svr->Get ("/lora-adapters", handle_lora_adapters_list);
+ svr->Post("/lora-adapters", handle_lora_adapters_apply);
+ // Save & load slots
+ svr->Get ("/slots", handle_slots);
if (!params.slot_save_path.empty()) {
// only enable slot endpoints if slot_save_path is set
svr->Post("/slots/:id_slot", handle_slots_action);