diff options
Diffstat (limited to 'tests/test-tokenizer-random.py')
-rw-r--r-- | tests/test-tokenizer-random.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py index 1166ac1e..7e1b656e 100644 --- a/tests/test-tokenizer-random.py +++ b/tests/test-tokenizer-random.py @@ -154,19 +154,22 @@ def generator_custom_text_edge_cases() -> Iterator[str]: '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM) 'Cửa Việt', # llama-3, ignore_merges = true '<s>a', # Phi-3 fail - '<unk><|endoftext|><s>' # Phi-3 fail + '<unk><|endoftext|><s>', # Phi-3 fail 'a\na', # TODO: Bert fail ] -def generator_random_special_tokens(special_tokens:list[str], iterations=100) -> Iterator[str]: - special_tokens = set(special_tokens) +def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]: + special_tokens = set(tokenizer.all_special_tokens) special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]) special_tokens = list(sorted(special_tokens)) rand = random.Random() for m in range(iterations): rand.seed(m) words = rand.choices(special_tokens, k=500) + if tokenizer.add_bos_token: # skip spam warning of double BOS + while words and words[0] == tokenizer.bos_token: + words.pop(0) yield "".join(words) @@ -290,18 +293,19 @@ def main(argv: list[str] = None): model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer) - def func_tokenize2(text: str): - return tokenizer.encode(text, add_special_tokens=False) - - parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens) + tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", True) + tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", False) def func_tokenize1(text: str): - return model.tokenize(text, add_special=False, parse_special=parse_special) + return model.tokenize(text, add_special=True, parse_special=True) + + def func_tokenize2(text: str): + return tokenizer.encode(text, add_special_tokens=True) vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True))) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text()) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases()) - test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer.all_special_tokens, 10_000)) + test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000)) test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000)) |