summaryrefslogtreecommitdiff
path: root/scripts/gen-unicode-data.py
diff options
context:
space:
mode:
authorjaime-m-p <167997752+jaime-m-p@users.noreply.github.com>2024-06-18 18:40:52 +0200
committerGitHub <noreply@github.com>2024-06-18 18:40:52 +0200
commit37bef8943312d91183ff06d8f1214082a17344a5 (patch)
tree7713dc5aceb3b181568db3d21b1383762de41c4a /scripts/gen-unicode-data.py
parent91c188d6c296bd3384f2a02a83b71187aa3d18b3 (diff)
tokenizer : BPE fixes (#7530)
* Random test: add_bos_token, add_eos_token * Random test: add BPE models for testing * Custom regex split fails with codepoint 0 * Fix falcon punctuation regex * Refactor llm_tokenizer_bpe: move code to constructor * Move 'add_special_bos/eos' logic to llm_tokenizer_bpe * Move tokenizer flags to vocab structure. * Default values for special_add_bos/eos * Build vocab.special_tokens_cache using vocab token types * Generalize 'jina-v2' per token attributes * Fix unicode whitespaces (deepseek-coder, deepseek-llm) * Skip missing byte tokens (falcon) * Better unicode data generation * Replace char32_t with uint32_t
Diffstat (limited to 'scripts/gen-unicode-data.py')
-rw-r--r--scripts/gen-unicode-data.py180
1 files changed, 120 insertions, 60 deletions
diff --git a/scripts/gen-unicode-data.py b/scripts/gen-unicode-data.py
index 744873c2..890e4d7c 100644
--- a/scripts/gen-unicode-data.py
+++ b/scripts/gen-unicode-data.py
@@ -1,83 +1,143 @@
-import regex
-import ctypes
+import array
import unicodedata
-
-
-class CoodepointFlags (ctypes.Structure):
- _fields_ = [ # see definition in unicode.h
- ("is_undefined", ctypes.c_uint16, 1),
- ("is_number", ctypes.c_uint16, 1), # regex: \p{N}
- ("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
- ("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
- ("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
- ("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
- ("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
- ("is_control", ctypes.c_uint16, 1), # regex: \p{C}
- ]
-
-
-assert (ctypes.sizeof(CoodepointFlags) == 2)
+import requests
MAX_CODEPOINTS = 0x110000
-regex_number = regex.compile(r'\p{N}')
-regex_letter = regex.compile(r'\p{L}')
-regex_separator = regex.compile(r'\p{Z}')
-regex_accent_mark = regex.compile(r'\p{M}')
-regex_punctuation = regex.compile(r'\p{P}')
-regex_symbol = regex.compile(r'\p{S}')
-regex_control = regex.compile(r'\p{C}')
-regex_whitespace = regex.compile(r'\s')
-
-codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
+UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
+
+
+# see https://www.unicode.org/L2/L1999/UnicodeData.html
+def unicode_data_iter():
+ res = requests.get(UNICODE_DATA_URL)
+ res.raise_for_status()
+ data = res.content.decode()
+
+ prev = []
+
+ for line in data.splitlines():
+ # ej: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
+ line = line.split(";")
+
+ cpt = int(line[0], base=16)
+ assert cpt < MAX_CODEPOINTS
+
+ cpt_lower = int(line[-2] or "0", base=16)
+ assert cpt_lower < MAX_CODEPOINTS
+
+ cpt_upper = int(line[-3] or "0", base=16)
+ assert cpt_upper < MAX_CODEPOINTS
+
+ categ = line[2].strip()
+ assert len(categ) == 2
+
+ bidir = line[4].strip()
+ assert len(categ) == 2
+
+ name = line[1]
+ if name.endswith(", First>"):
+ prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
+ continue
+ if name.endswith(", Last>"):
+ assert prev[1:] == (0, 0, categ, bidir)
+ for c in range(prev[0], cpt):
+ yield (c, cpt_lower, cpt_upper, categ, bidir)
+
+ yield (cpt, cpt_lower, cpt_upper, categ, bidir)
+
+
+# see definition in unicode.h
+CODEPOINT_FLAG_UNDEFINED = 0x0001 #
+CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
+CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
+CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
+CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
+CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
+CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
+CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}
+
+UNICODE_CATEGORY_TO_FLAG = {
+ "Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
+ "Cc": CODEPOINT_FLAG_CONTROL, # Control
+ "Cf": CODEPOINT_FLAG_CONTROL, # Format
+ "Co": CODEPOINT_FLAG_CONTROL, # Private Use
+ "Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate
+ "Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
+ "Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
+ "Lo": CODEPOINT_FLAG_LETTER, # Other Letter
+ "Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
+ "Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
+ "L&": CODEPOINT_FLAG_LETTER, # Cased Letter
+ "Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
+ "Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
+ "Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
+ "Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
+ "Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
+ "No": CODEPOINT_FLAG_NUMBER, # Other Number
+ "Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
+ "Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
+ "Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
+ "Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
+ "Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
+ "Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
+ "Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
+ "Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
+ "Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
+ "Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
+ "So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
+ "Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
+ "Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
+ "Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
+}
+
+
+codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
table_whitespace = []
table_lowercase = []
table_uppercase = []
table_nfd = []
-for codepoint in range(MAX_CODEPOINTS):
+for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
# convert codepoint to unicode character
- char = chr(codepoint)
-
- # regex categories
- flags = codepoint_flags[codepoint]
- flags.is_number = bool(regex_number.match(char))
- flags.is_letter = bool(regex_letter.match(char))
- flags.is_separator = bool(regex_separator.match(char))
- flags.is_accent_mark = bool(regex_accent_mark.match(char))
- flags.is_punctuation = bool(regex_punctuation.match(char))
- flags.is_symbol = bool(regex_symbol.match(char))
- flags.is_control = bool(regex_control.match(char))
- flags.is_undefined = bytes(flags)[0] == 0
- assert (not flags.is_undefined)
-
- # whitespaces
- if bool(regex_whitespace.match(char)):
- table_whitespace.append(codepoint)
+ char = chr(cpt)
+
+ # codepoint category flags
+ codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
# lowercase conversion
- lower = ord(char.lower()[0])
- if codepoint != lower:
- table_lowercase.append((codepoint, lower))
+ if cpt_lower:
+ table_lowercase.append((cpt, cpt_lower))
# uppercase conversion
- upper = ord(char.upper()[0])
- if codepoint != upper:
- table_uppercase.append((codepoint, upper))
+ if cpt_upper:
+ table_uppercase.append((cpt, cpt_upper))
# NFD normalization
norm = ord(unicodedata.normalize('NFD', char)[0])
- if codepoint != norm:
- table_nfd.append((codepoint, norm))
+ if cpt != norm:
+ table_nfd.append((cpt, norm))
+
+
+# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
+table_whitespace.extend(range(0x0009, 0x000D + 1))
+table_whitespace.extend(range(0x2000, 0x200A + 1))
+table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
+
+
+# sort by codepoint
+table_whitespace.sort()
+table_lowercase.sort()
+table_uppercase.sort()
+table_nfd.sort()
# group ranges with same flags
ranges_flags = [(0, codepoint_flags[0])] # start, flags
for codepoint, flags in enumerate(codepoint_flags):
- if bytes(flags) != bytes(ranges_flags[-1][1]):
+ if flags != ranges_flags[-1][1]:
ranges_flags.append((codepoint, flags))
-ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
+ranges_flags.append((MAX_CODEPOINTS, 0x0000))
# group ranges with same nfd
@@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
ranges_nfd[-1] = (start, codepoint, norm)
-# Generate 'unicode-data.cpp'
-
+# Generate 'unicode-data.cpp':
+# python ./scripts//gen-unicode-data.py > unicode-data.cpp
def out(line=""):
print(line, end='\n') # noqa
@@ -110,12 +170,12 @@ out("""\
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
for codepoint, flags in ranges_flags:
- flags = int.from_bytes(bytes(flags), "little")
out("{0x%06X, 0x%04X}," % (codepoint, flags))
out("};\n")
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
-out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
+for codepoint in table_whitespace:
+ out("0x%06X," % codepoint)
out("};\n")
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")