summary | refs | log | tree | commit | diff
path: root/gguf-py/gguf/vocab.py
diff options
context:
space:
mode:
author: slaren <slarengh@gmail.com> 2023-11-19 11:10:52 +0100
committer: GitHub <noreply@github.com> 2023-11-19 11:10:52 +0100
commit: e937066420b79a757bf80e9836eb12b88420a218 (patch)
tree: 93a3c3a9889c4367e6e6310674d0661cbf0f5d5f /gguf-py/gguf/vocab.py
parent: 28a2e6e7d476717881be6eb9e2d3331342cec57b (diff)
gguf-py : export chat templates (#4125)
* gguf-py : export chat templates
* llama.cpp : escape new lines in gguf kv info prints
* gguf-py : bump version
* gguf-py : check chat_template type
* gguf-py : initialize chat_template
Diffstat (limited to 'gguf-py/gguf/vocab.py')
-rw-r--r-- gguf-py/gguf/vocab.py 14
1 files changed, 14 insertions, 0 deletions
diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py
index b9f50a0a..de3e5edb 100644
--- a/gguf-py/gguf/vocab.py
+++ b/gguf-py/gguf/vocab.py
@@ -13,6 +13,7 @@ class SpecialVocab:
merges: list[str]
add_special_token: dict[str, bool]
special_token_ids: dict[str, int]
+ chat_template: str | None
def __init__(
self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -24,6 +25,7 @@ class SpecialVocab:
self.n_vocab = n_vocab
self.load_merges = load_merges
self.merges = []
+ self.chat_template = None
if special_token_types is not None:
self.special_token_types = special_token_types
else:
@@ -67,6 +69,10 @@ class SpecialVocab:
if not quiet:
print(f'gguf: Setting add_{typ}_token to {value}')
add_handler(value)
+ if self.chat_template is not None:
+ if not quiet:
+ print(f'gguf: Setting chat_template to {self.chat_template}')
+ gw.add_chat_template(self.chat_template)
def _load(self, path: Path) -> None:
self._try_load_from_tokenizer_json(path)
@@ -132,6 +138,14 @@ class SpecialVocab:
return True
with open(tokenizer_config_file, encoding = 'utf-8') as f:
tokenizer_config = json.load(f)
+ chat_template = tokenizer_config.get('chat_template')
+ if chat_template is None or isinstance(chat_template, str):
+ self.chat_template = chat_template
+ else:
+ print(
+ f'gguf: WARNING: Bad type for chat_template field in {tokenizer_config_file!r} - ignoring',
+ file = sys.stderr
+ )
for typ in self.special_token_types:
add_entry = tokenizer_config.get(f'add_{typ}_token')
if isinstance(add_entry, bool):