summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2024-02-06 07:47:22 +0200
committerGitHub <noreply@github.com>2024-02-06 07:47:22 +0200
commit906cff55c2848fda091d888a1585915ec0c9ea9e (patch)
treea1229b3104020dac71711b6128ee9df94b16f7e8
parent098f6d737b65134cf220d12b9b706e8cfc5e4610 (diff)
py : handle byte tokens in `get_token_type` (#5341)
* py : handle byte tokens in `get_token_type` * py : fix empty bytes arg
-rwxr-xr-xconvert.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/convert.py b/convert.py
index 75c10011..4a2847a2 100755
--- a/convert.py
+++ b/convert.py
@@ -515,10 +515,14 @@ class HfVocab:
# Yield token text, score, and type
yield token_text, self.get_token_score(token_id), self.get_token_type(
- token_id, self.special_ids # Reuse already stored special IDs
+ token_id, token_text, self.special_ids # Reuse already stored special IDs
)
- def get_token_type(self, token_id: int, special_ids: set[int]) -> gguf.TokenType:
+ def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
+ # Special case for byte tokens
+ if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
+ return gguf.TokenType.BYTE
+
# Determine token type based on whether it's a special token
return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL
@@ -530,7 +534,7 @@ class HfVocab:
def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
for text in self.added_tokens_list:
if text in self.specials:
- toktype = self.get_token_type(self.specials[text], self.special_ids)
+ toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
score = self.get_token_score(self.specials[text])
else:
toktype = gguf.TokenType.USER_DEFINED