diff options
author | Aleksey Nikiforov <lexn82@gmail.com> | 2025-07-14 12:43:52 -0400 |
---|---|---|
committer | GitHub <noreply@github.com> | 2025-07-14 18:43:52 +0200 |
commit | f5353047ef461e6fc9d527e09a06c9802c699929 (patch) | |
tree | 206c8c56efd3dcac1e39655e73788affe6c02832 | |
parent | 255c22046bcaef41850125be924f3e42e2a65571 (diff) |
Ported kimi-k2 support from llama.cpp (#609)
Original patch by @gabriellarson:
https://github.com/ggml-org/llama.cpp/pull/14654
Co-authored-by: anikifoss <anikifoss>
-rw-r--r-- | include/llama.h | 1 | ||||
-rw-r--r-- | src/llama-vocab.cpp | 7 | ||||
-rw-r--r-- | src/llama.cpp | 6 | ||||
-rw-r--r-- | src/unicode.cpp | 208 | ||||
-rw-r--r-- | src/unicode.h | 2 |
5 files changed, 223 insertions, 1 deletions
diff --git a/include/llama.h b/include/llama.h index 96895afa..eaa5d69d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -112,6 +112,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_FALCON_E = 35, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 36, //llama.cpp lists this as 35 LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 37, //llama.cpp lists this as 36 + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 38, //llama.cpp lists this as 37 }; // note: these values should be synchronized with ggml_rope diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 7bae4fec..109a6659 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -464,6 +464,13 @@ struct llm_tokenizer_bpe { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: + regex_exprs = { + // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp + // The custom handler implements all K2 patterns with proper Han character exclusion + "\\p{Han}+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: regex_exprs = { "\\p{N}+", diff --git a/src/llama.cpp b/src/llama.cpp index 5777689e..ac02abf6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -116,7 +116,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 256 // DeepSeekV2 +#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 // // helpers @@ -6402,6 +6402,10 @@ static void llm_load_vocab( tokenizer_pre == "hunyuan") { vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; vocab.tokenizer_clean_spaces = false; + } else if ( + tokenizer_pre == "kimi-k2") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; + vocab.tokenizer_clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/unicode.cpp b/src/unicode.cpp index a57456ea..c911fd26 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -374,6 +374,178 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te return bpe_offsets; } +// K2 system regex patterns (from tokenization_kimi.py): +// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +static std::vector<size_t> unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector<size_t> & offsets) { + std::vector<size_t> bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> codepoint_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // Pattern 1: [\p{Han}]+ (Chinese characters) + if (unicode_cpt_is_han(cpt)) { + while (unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + _add_token(pos); + continue; + } + + // Pattern 2 & 3: Letter words excluding Han characters with optional contractions + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)? + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)? + // Check if current char is a letter OR if current char could be a leading char and next char is a letter + bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) || + (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) && + _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1))); + + if (is_letter_pattern) { + // Handle optional leading non-letter/non-number character + bool has_leading_char = false; + if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) { + has_leading_char = true; + pos++; + } + + // Match letter sequence (excluding Han characters) + bool has_letters = false; + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + has_letters = true; + pos++; + } + + // Only proceed if we found letters (after potentially skipping leading char) + if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) { + if (!has_letters) pos++; // consume the first letter if we didn't already + + // Continue consuming letters + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + + // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d) + if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += 2; + } else if (pos + 2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += 3; + } + } + } + + _add_token(pos); + continue; + } else if (has_leading_char) { + // We consumed a leading char but found no letters, backtrack + pos--; + } + } + + // Pattern 4: \p{N}{1,3} (numbers 1-3 digits) + if (flags.is_number) { + size_t ini = pos; + while (_get_flags(pos).is_number) { + if (++pos - ini >= 3) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines) + auto flags2 = (cpt == ' ' ? _get_flags(pos + 1) : flags); + if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + // Match optional [\r\n]* + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + // Count whitespace characters + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos + num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // Pattern 6: \s*[\r\n]+ (whitespace with newlines) + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // Pattern 7: \s+(?!\S) (trailing whitespace) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // Pattern 8: \s+ (general whitespace) + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // No matches - consume single character + _add_token(++pos); + } + } + + return bpe_offsets; +} + // LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string& text, const std::vector<size_t>& offsets) { std::vector<size_t> bpe_offsets; // store the offset of each word @@ -587,6 +759,10 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); } + else if (regex_expr == "\\p{Han}+") { + // K2's first pattern - handle all K2 patterns together + bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); + } return bpe_offsets; } @@ -662,6 +838,38 @@ codepoint_flags unicode_cpt_flags(const std::string& utf8) { return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset)); } +bool unicode_cpt_is_han(uint32_t cpt) { + // Han character ranges (Chinese/CJK characters) + // CJK Unified Ideographs (most common) + if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true; + + // CJK Extension A + if (cpt >= 0x3400 && cpt <= 0x4DBF) return true; + + // CJK Extension B + if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true; + + // CJK Extension C + if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true; + + // CJK Extension D + if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true; + + // CJK Extension E + if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true; + + // CJK Extension F + if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true; + + // CJK Compatibility Ideographs + if (cpt >= 0xF900 && cpt <= 0xFAFF) return true; + + // CJK Compatibility Ideographs Supplement + if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true; + + return false; +} + std::string unicode_byte_to_utf8(uint8_t byte) { static std::unordered_map<uint8_t, std::string> map = unicode_byte_to_utf8_map(); return map.at(byte); diff --git a/src/unicode.h b/src/unicode.h index 008532a2..48940239 100644 --- a/src/unicode.h +++ b/src/unicode.h @@ -64,4 +64,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8); uint32_t unicode_tolower(uint32_t cp); +bool unicode_cpt_is_han(uint32_t cpt); + std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs); |