Diffstat (limited to 'gguf-py/gguf/vocab.py')
-rw-r--r--  gguf-py/gguf/vocab.py  164
1 file changed, 164 insertions, 0 deletions
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
new file mode 100644
index 00000000..71192a92
--- /dev/null
+++ b/gguf-py/gguf/vocab.py
@@ -0,0 +1,164 @@
+from __future__ import annotations
+
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Callable
+
+from .gguf_writer import GGUFWriter
+
+
+class SpecialVocab:
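+    # Collects a model's special-token metadata (BPE merges, special token ids
+    # and add_*_token flags) from its tokenizer/config files so that
+    # add_to_gguf() can write it out through a GGUFWriter.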
+    merges: list[str]
+    add_special_token: dict[str, bool]
+    special_token_ids: dict[str, int]
+
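+    # Loading happens eagerly: constructing the object scans `path` for
+    # tokenizer.json, tokenizer_config.json, config.json and (optionally)
+    # merges.txt, and caches whatever it finds. A typical caller (sketch only,
+    # the variable names below are illustrative) would do:
+    #   special_vocab = SpecialVocab(model_dir, load_merges = True, n_vocab = len(tokens))
+    #   special_vocab.add_to_gguf(gguf_writer)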
+    def __init__(
+        self, path: str | os.PathLike[str], load_merges: bool = False,
+        special_token_types: tuple[str, ...] | None = None,
+        n_vocab: int | None = None,
+    ):
+        self.special_token_ids = {}
+        self.add_special_token = {}
+        self.n_vocab = n_vocab
+        self.load_merges = load_merges
+        self.merges = []
+        if special_token_types is not None:
+            self.special_token_types = special_token_types
+        else:
+            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
+        self._load(Path(path))
+
+    def __repr__(self) -> str:
+        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
+            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
+        )
+
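+    # Emit the collected metadata via the writer's add_* helpers. Token types
+    # or flags without a matching GGUFWriter method are skipped with a warning
+    # instead of raising.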
+    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
+        if self.merges:
+            if not quiet:
+                print(f'gguf: Adding {len(self.merges)} merge(s).')
+            gw.add_token_merges(self.merges)
+        elif self.load_merges:
+            print(
+                'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.',
+                file = sys.stderr,
+            )
+        for typ, tokid in self.special_token_ids.items():
+            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
+            if id_handler is None:
+                print(
+                    f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping',
+                    file = sys.stderr,
+                )
+                continue
+            if not quiet:
+                print(f'gguf: Setting special token type {typ} to {tokid}')
+            id_handler(tokid)
+        for typ, value in self.add_special_token.items():
+            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
+            if add_handler is None:
+                print(
+                    f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping',
+                    file = sys.stderr,
+                )
+                continue
+            if not quiet:
+                print(f'gguf: Setting add_{typ}_token to {value}')
+            add_handler(value)
+
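+    # tokenizer.json is consulted first, then config.json (which never
+    # overwrites ids that were already found); merges.txt is only read when
+    # merges were requested but tokenizer.json did not provide any.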
+    def _load(self, path: Path) -> None:
+        self._try_load_from_tokenizer_json(path)
+        self._try_load_from_config_json(path)
+        if self.load_merges and not self.merges:
+            self._try_load_merges_txt(path)
+
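+    # Plain-text merges file: an optional '#...' header line followed by one
+    # "token_a token_b" pair per line. Malformed entries are warned about and
+    # skipped rather than treated as fatal.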
+    def _try_load_merges_txt(self, path: Path) -> bool:
+        merges_file = path / 'merges.txt'
+        if not merges_file.is_file():
+            return False
+        with open(merges_file, 'r', encoding = 'utf-8') as fp:
+            first_line = next(fp, '').strip()
+            if not first_line.startswith('#'):
+                fp.seek(0)
+                line_num = 0
+            else:
+                line_num = 1
+            merges = []
+            for line in fp:
+                line_num += 1
+                line = line.strip()
+                if not line:
+                    continue
+                parts = line.split(None, 3)
+                if len(parts) != 2:
+                    print(
+                        f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring',
+                        file = sys.stderr,
+                    )
+                    continue
+                merges.append(f'{parts[0]} {parts[1]}')
+        self.merges = merges
+        return True
+
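+    # Record a special token id, silently ignoring non-integer or negative
+    # values and ids already set for this type; ids outside a known n_vocab
+    # are rejected with a warning.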
+    def _set_special_token(self, typ: str, tid: Any) -> None:
+        if not isinstance(tid, int) or tid < 0:
+            return
+        if self.n_vocab is None or tid < self.n_vocab:
+            if typ in self.special_token_ids:
+                return
+            self.special_token_ids[typ] = tid
+            return
+        print(
+            f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping',
+            file = sys.stderr,
+        )
+
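+    # HF (fast) tokenizer layout: merges live under model.merges in
+    # tokenizer.json, while special token ids are resolved by matching the
+    # *_token entries from tokenizer_config.json against the added_tokens
+    # list by content.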
+    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
+        tokenizer_file = path / 'tokenizer.json'
+        if not tokenizer_file.is_file():
+            return False
+        with open(tokenizer_file, encoding = 'utf-8') as f:
+            tokenizer = json.load(f)
+        if self.load_merges:
+            merges = tokenizer.get('model', {}).get('merges')
+            if isinstance(merges, list) and merges and isinstance(merges[0], str):
+                self.merges = merges
+        tokenizer_config_file = path / 'tokenizer_config.json'
+        added_tokens = tokenizer.get('added_tokens')
+        if added_tokens is None or not tokenizer_config_file.is_file():
+            return True
+        with open(tokenizer_config_file, encoding = 'utf-8') as f:
+            tokenizer_config = json.load(f)
+        for typ in self.special_token_types:
+            add_entry = tokenizer_config.get(f'add_{typ}_token')
+            if isinstance(add_entry, bool):
+                self.add_special_token[typ] = add_entry
+            entry = tokenizer_config.get(f'{typ}_token')
+            if isinstance(entry, str):
+                tc_content = entry
+            elif isinstance(entry, dict):
+                entry_content = entry.get('content')
+                if not isinstance(entry_content, str):
+                    continue
+                tc_content = entry_content
+            else:
+                continue
+            # We only need the first match here.
+            maybe_token_id = next(
+                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
+                None,
+            )
+            self._set_special_token(typ, maybe_token_id)
+        return True
+
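+    # Fallback for models whose config.json carries bare <type>_token_id
+    # integers (e.g. bos_token_id, eos_token_id).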
+    def _try_load_from_config_json(self, path: Path) -> bool:
+        config_file = path / 'config.json'
+        if not config_file.is_file():
+            return False
+        with open(config_file, encoding = 'utf-8') as f:
+            config = json.load(f)
+        for typ in self.special_token_types:
+            self._set_special_token(typ, config.get(f'{typ}_token_id'))
+        return True