author    ubergarm <leimgrub@gmail.com>    2025-07-15 13:54:04 -0400
committer GitHub <noreply@github.com>      2025-07-15 19:54:04 +0200
commit    13b2f193723486f46efe34297cf797186ab14bc2 (patch)
tree      bda8a4b50adb20a564302e16dc42bed45ea798d4
parent    2081b3fccb9923699bf4d5e926d8719fc1d12c39 (diff)
kimi-k2 convert script and chat template (#612)
* convert_hf_to_gguf for Kimi-K2-Instruct

  Adapt mainline `PR14653` for the tokenizer while maintaining proper MLA tensors.
  Tested with this workflow: used deepseek fp8_cast_bf16.py and triton-cpu to upcast
  the fp8 safetensors to bf16 safetensors, then used this convert_hf_to_gguf.

* Add Kimi-K2 chat template

  moonshotai/Kimi-K2-Instruct
  https://github.com/ikawrakow/ik_llama.cpp/pull/609#issuecomment-3071259454

* kimi-k2 add ass to template to get response
-rw-r--r--  convert_hf_to_gguf.py           57
-rwxr-xr-x  convert_hf_to_gguf_update.py     1
-rw-r--r--  src/llama.cpp                   19
3 files changed, 77 insertions(+), 0 deletions(-)
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index b0a82c80..76f269b3 100644
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -639,6 +639,9 @@ class Model:
if chkhsh == "d5f1dd6f980fec569fb218a81a7658ac45fc56b38c5a0adeb1c232fbe04ef5ec":
# ref: https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base
res = "seed-coder"
+ if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890":
+ # ref: https://huggingface.co/moonshotai/Kimi-K2-Base
+ res = "kimi-k2"
if res is None:
logger.warning("\n")
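
The chkhsh values matched above are tokenizer fingerprints: get_vocab_base_pre() encodes a fixed test string with the model's tokenizer and hashes the resulting token ids. A simplified sketch of that check (the function name and paths are illustrative; the real test string lives in convert_hf_to_gguf.py and is not reproduced here):

    from hashlib import sha256
    from transformers import AutoTokenizer

    def tokenizer_fingerprint(model_dir: str, chktxt: str) -> str:
        # chktxt: the fixed test text defined in convert_hf_to_gguf.py
        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        return sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()

For moonshotai/Kimi-K2-Base this is expected to yield 81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890, the value the new branch keys on.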
@@ -3379,6 +3382,60 @@ class DeepseekV2Model(Model):
    model_arch = gguf.MODEL_ARCH.DEEPSEEK2

    def set_vocab(self):
+
+        if self.hparams["vocab_size"] == 163840: # Kimi-K2 model
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.dir_model, trust_remote_code=True
+            )
+            tokpre = self.get_vocab_base_pre(tokenizer)
+
+            # Build merges list using the approach similar to HunYuanMoE
+            merges = []
+            vocab = {}
+            mergeable_ranks = tokenizer.model._mergeable_ranks
+            for token, rank in mergeable_ranks.items():
+                vocab[QwenModel.token_bytes_to_string(token)] = rank
+                if len(token) == 1:
+                    continue
+                merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+                if len(merged) == 2:
+                    merges.append(
+                        " ".join(map(QwenModel.token_bytes_to_string, merged))
+                    )
+
+            # Build token list
+            vocab_size = self.hparams["vocab_size"]
+            special_tokens = tokenizer.special_tokens
+            reverse_vocab = {
+                id_: encoded_tok
+                for encoded_tok, id_ in {**vocab, **special_tokens}.items()
+            }
+            tokens: list[str] = []
+            toktypes: list[int] = []
+
+            for i in range(vocab_size):
+                if i not in reverse_vocab:
+                    tokens.append(f"[PAD{i}]")
+                    toktypes.append(gguf.TokenType.UNUSED)
+                else:
+                    token = reverse_vocab[i]
+                    tokens.append(token)
+                    if i in special_tokens.values():
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.NORMAL)
+
+            self.gguf_writer.add_tokenizer_model("gpt2")
+            self.gguf_writer.add_tokenizer_pre(tokpre)
+            self.gguf_writer.add_token_list(tokens)
+            self.gguf_writer.add_token_types(toktypes)
+            self.gguf_writer.add_token_merges(merges)
+
+            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+            special_vocab.add_to_gguf(self.gguf_writer)
+        else:
            self._set_vocab_gpt2()

    def set_gguf_parameters(self):
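
The merges list above is rebuilt from the tokenizer's tiktoken-style {bytes: rank} table rather than read from a tokenizer.json, following the HunYuanMoE approach via QwenModel.bpe and QwenModel.token_bytes_to_string. A self-contained toy sketch of the idea (hypothetical helper name bpe_parts and a toy rank table, not the repository's code):

    def bpe_parts(ranks: dict[bytes, int], token: bytes, max_rank: int) -> list[bytes]:
        # Split the token into single bytes, then greedily apply only merges whose
        # rank is strictly below max_rank (i.e. merges learned before this token).
        parts = [bytes([b]) for b in token]
        while True:
            best_idx, best_rank = None, None
            for i in range(len(parts) - 1):
                r = ranks.get(parts[i] + parts[i + 1])
                if r is not None and r < max_rank and (best_rank is None or r < best_rank):
                    best_idx, best_rank = i, r
            if best_idx is None:
                return parts
            parts = parts[:best_idx] + [parts[best_idx] + parts[best_idx + 1]] + parts[best_idx + 2:]

    # Toy rank table: single bytes first, learned merges afterwards in rank order.
    ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
    for token, rank in ranks.items():
        if len(token) == 1:
            continue
        parts = bpe_parts(ranks, token, max_rank=rank)
        if len(parts) == 2:
            print(token, "<-", parts)  # b'ab' <- [b'a', b'b'], then b'abc' <- [b'ab', b'c']

If a multi-byte token splits into exactly two pieces under the merges learned before it, that pair is the merge that created it, which is what the converter records as a "left right" merge string.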
diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py
index f2e6cc37..d6541987 100755
--- a/convert_hf_to_gguf_update.py
+++ b/convert_hf_to_gguf_update.py
@@ -96,6 +96,7 @@ models = [
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+ {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890", },
]
diff --git a/src/llama.cpp b/src/llama.cpp
index 0a81f2b9..58812fc8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1695,6 +1695,7 @@ enum llm_chat_template {
    LLM_CHAT_TEMPLATE_BITNET,
    LLM_CHAT_TEMPLATE_DOTS1,
    LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+   LLM_CHAT_TEMPLATE_KIMI_K2,
    LLM_CHAT_TEMPLATE_UNKNOWN,
};
@@ -1733,6 +1734,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "megrez", LLM_CHAT_TEMPLATE_MEGREZ },
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+ { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "bitnet", LLM_CHAT_TEMPLATE_BITNET },
};
@@ -23270,6 +23272,8 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) {
        return LLM_CHAT_TEMPLATE_DOTS1;
    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+   } else if (tmpl_contains("<|im_middle|>") && tmpl_contains("<|im_end|>")) {
+       return LLM_CHAT_TEMPLATE_KIMI_K2;
    }
    return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -23715,6 +23719,21 @@ static int32_t llama_chat_apply_template_internal(
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
+   } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) {
+       // moonshotai/Kimi-K2-Instruct
+       for (auto message : chat) {
+           std::string role(message->role);
+           if (role == "system") {
+               ss << "<|im_system|>system<|im_middle|>" << message->content << "<|im_end|>";
+           } else if (role == "user") {
+               ss << "<|im_user|>user<|im_middle|>" << message->content << "<|im_end|>";
+           } else {
+               ss << "<|im_assistant|>assistant<|im_middle|>" << message->content << "<|im_end|>";
+           }
+       }
+       if (add_ass) {
+           ss << "<|im_assistant|>assistant<|im_middle|>";
+       }
    } else {
        // template not supported
        return -1;
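
For reference, with add_ass enabled the branch above renders a system prompt followed by one user turn as (message contents are placeholders):

    <|im_system|>system<|im_middle|>You are Kimi.<|im_end|><|im_user|>user<|im_middle|>Hello!<|im_end|><|im_assistant|>assistant<|im_middle|>

The trailing <|im_assistant|>assistant<|im_middle|> is the generation prompt referred to by "kimi-k2 add ass to template to get response" in the commit message.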