diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2024-05-04 08:32:32 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-05-04 08:32:32 +0300 |
commit | 92139b90af4841d7fd060b526bdd443b621770ff (patch) | |
tree | 9679c3de1b39970ca73b5bd988c63ddac0359ca6 /scripts/gen-unicode-data.py | |
parent | a2ac89d6efb41b535778bfeaecaae8fe295b6ed3 (diff) |
tests : add test-tokenizer-0.sh + fix some tokenizers (#7036)
* tests : add test-tokenizer-0.sh
* unicode : add all unicode number ranges
* starcoder : fix pre-tokenizer
* tests : add test that fails with DeepSeek tokenizers
* falcon : fix regex
* unicode : regenerate unicode tables
* refact : add tokenizer model
* lint : fix
* tests : disable failing tests
ggml-ci
* refact : add tests files
ggml-ci
* convert : print -> logging
ggml-ci
* lint : fix
* unicode : digit -> number
* phi-3 : update
Diffstat (limited to 'scripts/gen-unicode-data.py')
-rw-r--r-- | scripts/gen-unicode-data.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/scripts/gen-unicode-data.py b/scripts/gen-unicode-data.py new file mode 100644 index 00000000..d49cbf2a --- /dev/null +++ b/scripts/gen-unicode-data.py @@ -0,0 +1,66 @@ +import regex + + +def cpt_to_utf8_str(cpt): + if cpt <= 0xFF: + return bytes([cpt, 0, 0, 0]) + elif cpt <= 0xFFFF: + return bytes([cpt & 0xFF, cpt >> 8, 0, 0]) + elif cpt <= 0xFFFFFF: + return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, 0]) + else: + return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, cpt >> 24]) + + +def is_match(codepoint, regex_expr): + try: + res = regex.match(regex_expr, cpt_to_utf8_str(codepoint).decode('utf-32')) + return res is not None + except Exception: + return False + + +def get_matches(regex_expr): + unicode_ranges = [] + current_range = None + + for codepoint in range(0x110000): + if is_match(codepoint, regex_expr): + if current_range is None: + current_range = [codepoint, codepoint] + else: + current_range[1] = codepoint + elif current_range is not None: + unicode_ranges.append(tuple(current_range)) + current_range = None + + if current_range is not None: + unicode_ranges.append(tuple(current_range)) + + return unicode_ranges + + +def print_cat(cat, ranges): + print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) + cnt = 0 + for start, end in ranges: + if cnt % 4 != 0: + print(" ", end="") + print("{{0x{:08X}, 0x{:08X}}},".format(start, end), end="") + if cnt % 4 == 3: + print("") + cnt += 1 + + if cnt % 4 != 0: + print("") + print("};") + print("") + + +print_cat("number", get_matches(r'\p{N}')) +print_cat("letter", get_matches(r'\p{L}')) +print_cat("whitespace", get_matches(r'\p{Z}')) +print_cat("accent_mark", get_matches(r'\p{M}')) +print_cat("punctuation", get_matches(r'\p{P}')) +print_cat("symbol", get_matches(r'\p{S}')) +print_cat("control", get_matches(r'\p{C}')) |