Diffstat (limited to 'tests/test-chat-template.cpp')
 tests/test-chat-template.cpp | 60 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 57 insertions(+), 3 deletions(-)
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index cef9a650..a8222cae 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -1,4 +1,3 @@
-#include <iostream>
 #include <string>
 #include <vector>
 #include <sstream>
@@ -7,6 +6,7 @@
 #include <cassert>
 
 #include "llama.h"
+#include "common.h"
 
 int main(void) {
     llama_chat_message conversation[] = {
@@ -56,7 +56,15 @@ int main(void) {
         //Phi-3-medium
         "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
         //Phi-3-vision
-        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}"
+        "{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
+        // ChatGLM3
+        "{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // ChatGLM4
+        u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
+        // DeepSeek-V2
+        "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
     };
     std::vector<std::string> expected_output = {
         // teknium/OpenHermes-2.5-Mistral-7B
@@ -93,6 +101,14 @@ int main(void) {
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
         //Phi-3-vision
         "<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
+        // ChatGLM3
+        "[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
+        // ChatGLM4
+        "[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
+        // MiniCPM-3B-OpenHermes-2.5-v2-GGUF
+        u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
+        // DeepSeek-V2
+        u8"You are a helpful assistant\n\nUser: Hello\n\nAssistant: Hi there<|end▁of▁sentence|>User: Who are you\n\nAssistant: I am an assistant <|end▁of▁sentence|>User: Another question\n\nAssistant:",
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
@@ -116,8 +132,46 @@ int main(void) {
         );
         formatted_chat.resize(res);
         std::string output(formatted_chat.data(), formatted_chat.size());
-        std::cout << output << "\n-------------------------\n";
+        printf("%s\n", output.c_str());
+        printf("-------------------------\n");
         assert(output == expected);
     }
+
+
+    // test llama_chat_format_single for system message
+    printf("\n\n=== llama_chat_format_single (system message) ===\n\n");
+    std::vector<llama_chat_msg> chat2;
+    llama_chat_msg sys_msg{"system", "You are a helpful assistant"};
+
+    auto fmt_sys = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, sys_msg, false);
+        printf("fmt_sys(%s) : %s\n", tmpl.c_str(), output.c_str());
+        printf("-------------------------\n");
+        return output;
+    };
+    assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n");
+    assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n");
+    assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message
+    assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>");
+
+
+    // test llama_chat_format_single for user message
+    printf("\n\n=== llama_chat_format_single (user message) ===\n\n");
+    chat2.push_back({"system", "You are a helpful assistant"});
+    chat2.push_back({"user", "Hello"});
+    chat2.push_back({"assistant", "I am assistant"});
+    llama_chat_msg new_msg{"user", "How are you"};
+
+    auto fmt_single = [&](std::string tmpl) {
+        auto output = llama_chat_format_single(nullptr, tmpl, chat2, new_msg, true);
+        printf("fmt_single(%s) : %s\n", tmpl.c_str(), output.c_str());
+        printf("-------------------------\n");
+        return output;
+    };
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("llama2") == "[INST] How are you [/INST]");
+    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
+
     return 0;
 }
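
For context on what this diff tests, here is a minimal sketch of driving the two entry points involved: llama_chat_apply_template() from llama.h and the llama_chat_format_single() helper from common.h that this change starts testing. It is a sketch against the signatures implied by this revision of llama.cpp, not part of the commit; the resize-and-retry step reflects that llama_chat_apply_template() returns the required buffer size.

    // Sketch only: assumes the llama.cpp headers at this revision.
    #include <cstdio>
    #include <string>
    #include <vector>

    #include "llama.h"
    #include "common.h"

    int main(void) {
        // C API: format a whole conversation with a named built-in template.
        // Passing nullptr for the model means the template argument is used directly.
        llama_chat_message chat[] = {
            {"system", "You are a helpful assistant"},
            {"user",   "Hello"},
        };
        std::vector<char> buf(1024);
        int32_t res = llama_chat_apply_template(
            nullptr, "chatml", chat, 2, /*add_ass=*/true, buf.data(), buf.size());
        if (res > (int32_t) buf.size()) {
            // Return value is the required size; grow the buffer and retry.
            buf.resize(res);
            res = llama_chat_apply_template(
                nullptr, "chatml", chat, 2, true, buf.data(), buf.size());
        }
        printf("%.*s\n", res, buf.data());

        // C++ helper exercised by the new assertions: formats only the delta
        // that new_msg adds on top of the already-formatted past messages.
        std::vector<llama_chat_msg> past = {
            {"system",    "You are a helpful assistant"},
            {"user",      "Hello"},
            {"assistant", "Hi there"},
        };
        llama_chat_msg new_msg{"user", "How are you"};
        std::string delta = llama_chat_format_single(nullptr, "chatml", past, new_msg, true);
        printf("%s\n", delta.c_str());
        return 0;
    }

The delta behavior is why the test asserts a leading "\n" for chatml in fmt_single: the helper formats the chat with and without the new message and returns the suffix, so any trailing whitespace the template places after the last past message lands at the front of the returned string.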