summaryrefslogtreecommitdiff
path: root/scripts/gen-unicode-data.py
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-05-04 08:32:32 +0300
committerGitHub <noreply@github.com>2024-05-04 08:32:32 +0300
commit92139b90af4841d7fd060b526bdd443b621770ff (patch)
tree9679c3de1b39970ca73b5bd988c63ddac0359ca6 /scripts/gen-unicode-data.py
parenta2ac89d6efb41b535778bfeaecaae8fe295b6ed3 (diff)
tests : add test-tokenizer-0.sh + fix some tokenizers (#7036)
* tests : add test-tokenizer-0.sh * unicode : add all unicode number ranges * starcoder : fix pre-tokenizer * tests : add test that fails with DeepSeek tokenizers * falcon : fix regex * unicode : regenerate unicode tables * refact : add tokenizer model * lint : fix * tests : disable failing tests ggml-ci * refact : add tests files ggml-ci * convert : print -> logging ggml-ci * lint : fix * unicode : digit -> number * phi-3 : update
Diffstat (limited to 'scripts/gen-unicode-data.py')
-rw-r--r--scripts/gen-unicode-data.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/scripts/gen-unicode-data.py b/scripts/gen-unicode-data.py
new file mode 100644
index 00000000..d49cbf2a
--- /dev/null
+++ b/scripts/gen-unicode-data.py
@@ -0,0 +1,66 @@
+import regex
+
+
+def cpt_to_utf8_str(cpt):
+ if cpt <= 0xFF:
+ return bytes([cpt, 0, 0, 0])
+ elif cpt <= 0xFFFF:
+ return bytes([cpt & 0xFF, cpt >> 8, 0, 0])
+ elif cpt <= 0xFFFFFF:
+ return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, 0])
+ else:
+ return bytes([cpt & 0xFF, (cpt >> 8) & 0xFF, (cpt >> 16) & 0xFF, cpt >> 24])
+
+
+def is_match(codepoint, regex_expr):
+ try:
+ res = regex.match(regex_expr, cpt_to_utf8_str(codepoint).decode('utf-32'))
+ return res is not None
+ except Exception:
+ return False
+
+
+def get_matches(regex_expr):
+ unicode_ranges = []
+ current_range = None
+
+ for codepoint in range(0x110000):
+ if is_match(codepoint, regex_expr):
+ if current_range is None:
+ current_range = [codepoint, codepoint]
+ else:
+ current_range[1] = codepoint
+ elif current_range is not None:
+ unicode_ranges.append(tuple(current_range))
+ current_range = None
+
+ if current_range is not None:
+ unicode_ranges.append(tuple(current_range))
+
+ return unicode_ranges
+
+
+def print_cat(cat, ranges):
+ print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat))
+ cnt = 0
+ for start, end in ranges:
+ if cnt % 4 != 0:
+ print(" ", end="")
+ print("{{0x{:08X}, 0x{:08X}}},".format(start, end), end="")
+ if cnt % 4 == 3:
+ print("")
+ cnt += 1
+
+ if cnt % 4 != 0:
+ print("")
+ print("};")
+ print("")
+
+
+print_cat("number", get_matches(r'\p{N}'))
+print_cat("letter", get_matches(r'\p{L}'))
+print_cat("whitespace", get_matches(r'\p{Z}'))
+print_cat("accent_mark", get_matches(r'\p{M}'))
+print_cat("punctuation", get_matches(r'\p{P}'))
+print_cat("symbol", get_matches(r'\p{S}'))
+print_cat("control", get_matches(r'\p{C}'))