Diffstat (limited to 'examples/server/utils.hpp')
-rw-r--r--  examples/server/utils.hpp  43
1 file changed, 21 insertions(+), 22 deletions(-)
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 63fde9c9..db6b3b74 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -118,36 +118,35 @@ static inline void server_log(const char * level, const char * function, int lin
 // Format given chat. If tmpl is empty, we take the template from model metadata
 inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages) {
-    size_t alloc_size = 0;
-    // vector holding all allocated string to be passed to llama_chat_apply_template
-    std::vector<std::string> str(messages.size() * 2);
-    std::vector<llama_chat_message> chat(messages.size());
+    std::vector<llama_chat_msg> chat;
     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
-        str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
-        str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
-        alloc_size += str[i*2 + 1].length();
-        chat[i].role = str[i*2 + 0].c_str();
-        chat[i].content = str[i*2 + 1].c_str();
-    }
-
-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
-    std::vector<char> buf(alloc_size * 2);
-    // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
+        std::string role = json_value(curr_msg, "role", std::string(""));
+
+        std::string content;
+        if (curr_msg.contains("content")) {
+            if (curr_msg["content"].is_string()) {
+                content = curr_msg["content"].get<std::string>();
+            } else if (curr_msg["content"].is_array()) {
+                for (const auto & part : curr_msg["content"]) {
+                    if (part.contains("text")) {
+                        content += "\n" + part["text"].get<std::string>();
+                    }
+                }
+            } else {
+                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+            }
+        } else {
+            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+        }
-    // if it turns out that our buffer is too small, we resize it
-    if ((size_t) res > buf.size()) {
-        buf.resize(res);
-        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
+        chat.push_back({role, content});
     }
-    const std::string formatted_chat(buf.data(), res);
-
+    auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
     LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
-
     return formatted_chat;
 }
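
Usage note (not part of the diff): after this change, format_chat() accepts "content" either as a plain string or as an OpenAI-style array of text parts; anything else, or a missing "content" key, throws std::runtime_error. The sketch below is an illustration only, assuming the nlohmann JSON alias used elsewhere in utils.hpp; the message values are hypothetical.

// Illustration only: the two message shapes accepted by the new format_chat().
// The array form is flattened by appending "\n" + part["text"] for each part,
// so multi-part content ends up newline-separated (and newline-prefixed).
#include <string>
#include <vector>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

static std::vector<json> example_messages() {
    return {
        // 1) plain string content
        json{{"role", "user"}, {"content", "Hello"}},
        // 2) array of text parts (ref: issue #8367)
        json{{"role", "user"}, {"content", json::array({
            json{{"type", "text"}, {"text", "Describe the picture."}},
            json{{"type", "text"}, {"text", "Answer in one sentence."}}
        })}}
    };
}

// Either entry can then be formatted with, e.g.:
//     std::string prompt = format_chat(model, tmpl, example_messages());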