author     Xuan Son Nguyen <thichthat@gmail.com>   2024-02-20 15:58:27 +0100
committer  GitHub <noreply@github.com>             2024-02-20 15:58:27 +0100
commit     9c405c9f9a7cfd23511fd6b2de05dc72481119b4 (patch)
tree       694b5a169d63eb4640df2d6f536d384cc481b300 /examples/server/utils.hpp
parent     5207b3fbc500f89dfe528693e96540956dbaed96 (diff)
Server: use llama_chat_apply_template (#5593)
* server: use llama_chat_apply_template
* server: remove trailing space
* server: fix format_chat
* server: fix help message
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* server: fix formatted_chat
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
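
For context, the replacement helpers in the diff below are built on the public llama_chat_apply_template API, whose signature (model, template, message array, message count, add-assistant flag, output buffer, buffer length) is visible in the diff itself. The following is a minimal standalone sketch of the same two-pass call pattern; the function name, placeholder conversation, initial buffer size, and error handling are illustrative assumptions, not part of this commit:

```cpp
#include <string>
#include <vector>

#include "llama.h"

// Sketch of the two-pass sizing idiom used by the new format_chat below.
// `model` may be nullptr when a custom template string is supplied in `tmpl`.
static std::string apply_template_sketch(const struct llama_model * model, const char * tmpl) {
    // placeholder conversation; llama_chat_message holds C-string role/content pairs
    std::vector<llama_chat_message> chat = {
        {"system", "You are a helpful assistant."},
        {"user",   "Hello!"},
    };

    std::vector<char> buf(256); // initial guess (an assumption, not from the commit)

    // first pass: returns the total output length, which may exceed the buffer
    int32_t res = llama_chat_apply_template(model, tmpl, chat.data(), chat.size(),
                                            true, buf.data(), buf.size());
    if (res < 0) {
        return ""; // template not supported (verify_custom_template checks res >= 0)
    }

    // second pass only if the first did not fit
    if ((size_t) res > buf.size()) {
        buf.resize(res);
        res = llama_chat_apply_template(model, tmpl, chat.data(), chat.size(),
                                        true, buf.data(), buf.size());
    }
    return std::string(buf.data(), res);
}
```

Passing nullptr for the model together with a non-empty template string is exactly what the new verify_custom_template relies on to test a user-supplied template without consulting model metadata.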
Diffstat (limited to 'examples/server/utils.hpp')
-rw-r--r--  examples/server/utils.hpp | 69
1 file changed, 33 insertions(+), 36 deletions(-)
```diff
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 0ee670db..e954fb0e 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
         : default_value;
 }
 
-inline std::string format_llama2(std::vector<json> messages)
-{
-    std::ostringstream output;
-    bool is_inside_turn = false;
+// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
+inline bool verify_custom_template(const std::string & tmpl) {
+    llama_chat_message chat[] = {{"user", "test"}};
+    std::vector<char> buf(1);
+    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
+    return res >= 0;
+}
 
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        if (!is_inside_turn) {
-            output << "[INST] ";
-        }
-        std::string role    = json_value(*it, "role", std::string("user"));
-        std::string content = json_value(*it, "content", std::string(""));
-        if (role == "system") {
-            output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
-            is_inside_turn = true;
-        } else if (role == "user") {
-            output << content << " [/INST]";
-            is_inside_turn = true;
-        } else {
-            output << " " << content << " </s>";
-            is_inside_turn = false;
-        }
+// Format given chat. If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
+{
+    size_t alloc_size = 0;
+    // vector holding all allocated string to be passed to llama_chat_apply_template
+    std::vector<std::string> str(messages.size() * 2);
+    std::vector<llama_chat_message> chat(messages.size());
+
+    for (size_t i = 0; i < messages.size(); ++i) {
+        auto &curr_msg = messages[i];
+        str[i*2 + 0]    = json_value(curr_msg, "role", std::string(""));
+        str[i*2 + 1]    = json_value(curr_msg, "content", std::string(""));
+        alloc_size     += str[i*2 + 1].length();
+        chat[i].role    = str[i*2 + 0].c_str();
+        chat[i].content = str[i*2 + 1].c_str();
     }
 
-    LOG_VERBOSE("format_llama2", {{"text", output.str()}});
+    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    std::vector<char> buf(alloc_size * 2);
 
-    return output.str();
-}
-
-inline std::string format_chatml(std::vector<json> messages)
-{
-    std::ostringstream chatml_msgs;
+    // run the first time to get the total output length
+    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
 
-    for (auto it = messages.begin(); it != messages.end(); ++it) {
-        chatml_msgs << "<|im_start|>"
-                    << json_value(*it, "role", std::string("user")) << '\n';
-        chatml_msgs << json_value(*it, "content", std::string(""))
-                    << "<|im_end|>\n";
+    // if it turns out that our buffer is too small, we resize it
+    if ((size_t) res > buf.size()) {
+        buf.resize(res);
+        res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
     }
 
-    chatml_msgs << "<|im_start|>assistant" << '\n';
-
-    LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
+    std::string formatted_chat(buf.data(), res);
+    LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
 
-    return chatml_msgs.str();
+    return formatted_chat;
 }
 
 //
```
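
A hypothetical call site tying the two new helpers together: apart from verify_custom_template and format_chat themselves, every name here (build_prompt, the startup-style check, the json alias) is illustrative rather than taken from the server code:

```cpp
#include <stdexcept>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::json; // the server aliases nlohmann::json as `json`

// Hypothetical glue: validate an optional user template, then render a prompt.
static std::string build_prompt(const struct llama_model * model,
                                const std::string & custom_tmpl, // "" = use model metadata
                                const std::vector<json> & messages) {
    if (!custom_tmpl.empty() && !verify_custom_template(custom_tmpl)) {
        // a bad --chat-template is best rejected up front, before serving requests
        throw std::runtime_error("the supplied chat template is not supported");
    }
    return format_chat(model, custom_tmpl, messages);
}
```

Note that format_chat sizes its buffer at twice the total message content length before the first call, so for typical templates the second llama_chat_apply_template pass is skipped.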