summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/gguf_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r--gguf-py/gguf/gguf_writer.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index ff9326d5..e3dbca45 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -6,7 +6,8 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
-from typing import IO, Any, Sequence
+from typing import IO, Any, Sequence, Mapping
+from string import ascii_letters, digits
import numpy as np
@@ -466,7 +467,33 @@ class GGUFWriter:
def add_add_space_prefix(self, value: bool) -> None:
    """Write ``value`` as a boolean under ``Keys.Tokenizer.ADD_PREFIX``."""
    key = Keys.Tokenizer.ADD_PREFIX
    self.add_bool(key, value)
def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
    """Store a single chat template, or a collection of named chat templates.

    A plain string is written as-is under ``Keys.Tokenizer.CHAT_TEMPLATE``.

    Any other value is iterated as a sequence of mappings, each read through
    its ``'name'`` and ``'template'`` entries:

    * every character of the name that is not an ASCII letter or digit is
      replaced with ``'_'``;
    * entries with an empty name or without a template are skipped;
    * the entry named ``'default'`` is written under
      ``Keys.Tokenizer.CHAT_TEMPLATE`` (the last such entry wins);
    * every other entry is written under ``Keys.Tokenizer.CHAT_TEMPLATE_N``
      formatted with its sanitized name, and the distinct names are stored,
      in first-seen order, in the ``Keys.Tokenizer.CHAT_TEMPLATES`` array.

    When no ``'default'`` entry is present, ``Keys.Tokenizer.CHAT_TEMPLATE``
    is not written at all.
    """
    # Test for "not a plain string" rather than "is a list" so that tuples and
    # other sequences accepted by the annotation take the multi-template path
    # instead of being handed to add_string() unchanged.
    if not isinstance(value, str):
        # Built once; membership of a single character is equivalent to the
        # substring test against ascii_letters + digits.
        allowed_chars = frozenset(ascii_letters + digits)

        template_default = None
        # A dict used as an ordered set: de-duplicates names while keeping
        # first-seen order, so the emitted array is identical on every run
        # (iterating a set of str is subject to per-process hash randomization).
        template_names = {}

        for choice in value:
            name = choice.get('name', '')
            template = choice.get('template')

            # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
            name = ''.join(c if c in allowed_chars else '_' for c in name)

            if not name or template is None:
                continue

            if name == 'default':
                template_default = template
            else:
                template_names[name] = None
                self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template)

        if template_names:
            self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names))

        if template_default is None:
            return

        value = template_default

    self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
def add_prefix_token_id(self, id: int) -> None: