author | slaren <slarengh@gmail.com> | 2023-11-19 11:10:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-19 11:10:52 +0100 |
commit | e937066420b79a757bf80e9836eb12b88420a218 (patch) | |
tree | 93a3c3a9889c4367e6e6310674d0661cbf0f5d5f /gguf-py/gguf/vocab.py | |
parent | 28a2e6e7d476717881be6eb9e2d3331342cec57b (diff) | |
gguf-py : export chat templates (#4125)
* gguf-py : export chat templates
* llama.cpp : escape new lines in gguf kv info prints
* gguf-py : bump version
* gguf-py : check chat_template type
* gguf-py : initialize chat_template
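The last two bullets cover the loading side: `chat_template` is read from the model's `tokenizer_config.json` and accepted only when it is absent or a plain string. Below is a standalone sketch of that behaviour; the `load_chat_template` helper name is illustrative and not part of the gguf-py API, since in gguf-py the logic sits inline in `SpecialVocab`, as the diff further down shows.

```python
from __future__ import annotations

import json
import sys
from pathlib import Path

def load_chat_template(model_dir: Path) -> str | None:
    """Illustrative re-statement of the type-checked load added by this commit."""
    tokenizer_config_file = model_dir / 'tokenizer_config.json'
    if not tokenizer_config_file.is_file():
        return None
    with open(tokenizer_config_file, encoding='utf-8') as f:
        tokenizer_config = json.load(f)
    chat_template = tokenizer_config.get('chat_template')
    # A missing field or a string value is accepted as-is (None means "no template").
    if chat_template is None or isinstance(chat_template, str):
        return chat_template
    # Anything else is ignored with a warning on stderr rather than aborting.
    print(f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
          file=sys.stderr)
    return None
```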
Diffstat (limited to 'gguf-py/gguf/vocab.py')
-rw-r--r-- | gguf-py/gguf/vocab.py | 14
1 file changed, 14 insertions, 0 deletions
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
index b9f50a0a..de3e5edb 100644
--- a/gguf-py/gguf/vocab.py
+++ b/gguf-py/gguf/vocab.py
@@ -13,6 +13,7 @@ class SpecialVocab:
     merges: list[str]
     add_special_token: dict[str, bool]
     special_token_ids: dict[str, int]
+    chat_template: str | None
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -24,6 +25,7 @@ class SpecialVocab:
         self.n_vocab = n_vocab
         self.load_merges = load_merges
         self.merges = []
+        self.chat_template = None
         if special_token_types is not None:
             self.special_token_types = special_token_types
         else:
@@ -67,6 +69,10 @@ class SpecialVocab:
             if not quiet:
                 print(f'gguf: Setting add_{typ}_token to {value}')
             add_handler(value)
+        if self.chat_template is not None:
+            if not quiet:
+                print(f'gguf: Setting chat_template to {self.chat_template}')
+            gw.add_chat_template(self.chat_template)
 
     def _load(self, path: Path) -> None:
         self._try_load_from_tokenizer_json(path)
@@ -132,6 +138,14 @@ class SpecialVocab:
             return True
         with open(tokenizer_config_file, encoding = 'utf-8') as f:
             tokenizer_config = json.load(f)
+        chat_template = tokenizer_config.get('chat_template')
+        if chat_template is None or isinstance(chat_template, str):
+            self.chat_template = chat_template
+        else:
+            print(
+                f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
+                file = sys.stderr
+            )
         for typ in self.special_token_types:
             add_entry = tokenizer_config.get(f'add_{typ}_token')
             if isinstance(add_entry, bool):
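With that in place, any conversion script that already routes its vocabulary through `SpecialVocab.add_to_gguf` picks the template up automatically, writing it under the `tokenizer.chat_template` key; the matching `GGUFWriter.add_chat_template` method comes from the same PR (this page filters the diff to vocab.py). A minimal sketch of the flow, where the file names and the 'llama' architecture are placeholder assumptions rather than part of this commit:

```python
from gguf import GGUFWriter, SpecialVocab

# Placeholder output path and architecture, for illustration only.
writer = GGUFWriter('model.gguf', 'llama')

# Loads merges, special tokens and, as of this commit, chat_template
# from the (placeholder) HF model directory.
special_vocab = SpecialVocab('path/to/hf-model', load_merges=True)

# Emits the template as a GGUF KV entry when one was found, printing
# "gguf: Setting chat_template to ..." unless quiet=True is passed.
special_vocab.add_to_gguf(writer)

# A real converter would continue with hyperparameters and tensor data
# before writing the file out; omitted here.
```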