summaryrefslogtreecommitdiff
path: root/convert.py
diff options
context:
space:
mode:
Diffstat (limited to 'convert.py')
-rwxr-xr-xconvert.py21
1 files changed, 12 insertions, 9 deletions
diff --git a/convert.py b/convert.py
index 0680f71e..bfbfab28 100755
--- a/convert.py
+++ b/convert.py
@@ -366,16 +366,19 @@ class SentencePieceVocab:
added_tokens = {}
vocab_size: int = self.sentencepiece_tokenizer.vocab_size()
- expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
- actual_ids = sorted(added_tokens.values())
- if expected_ids != actual_ids:
- raise Exception(f"Expected added token IDs to be sequential and start at {vocab_size}; got {actual_ids}")
- items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
- self.added_tokens_list = [text for (text, idx) in items]
- self.vocab_size_base: int = vocab_size
- self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list)
- self.fname_tokenizer = fname_tokenizer
+ new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size}
+ expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens)))
+ actual_new_ids = sorted(new_tokens.keys())
+
+ if expected_new_ids != actual_new_ids:
+ raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")
+
+ # Token pieces that were added to the base vocabulary.
+ self.added_tokens_list = [new_tokens[id] for id in actual_new_ids]
+ self.vocab_size_base = vocab_size
+ self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
+ self.fname_tokenizer = fname_tokenizer
self.fname_added_tokens = fname_added_tokens
def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: