summaryrefslogtreecommitdiff
path: root/gguf-py/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/scripts')
-rw-r--r--gguf-py/scripts/__init__.py1
-rw-r--r--gguf-py/scripts/gguf-new-metadata.py190
2 files changed, 191 insertions, 0 deletions
diff --git a/gguf-py/scripts/__init__.py b/gguf-py/scripts/__init__.py
index 77132db7..1ad45639 100644
--- a/gguf-py/scripts/__init__.py
+++ b/gguf-py/scripts/__init__.py
@@ -8,5 +8,6 @@ os.environ["NO_LOCAL_GGUF"] = "TRUE"
gguf_convert_endian_entrypoint = import_module("scripts.gguf-convert-endian").main
gguf_dump_entrypoint = import_module("scripts.gguf-dump").main
gguf_set_metadata_entrypoint = import_module("scripts.gguf-set-metadata").main
+gguf_new_metadata_entrypoint = import_module("scripts.gguf-new-metadata").main
del import_module, os
diff --git a/gguf-py/scripts/gguf-new-metadata.py b/gguf-py/scripts/gguf-new-metadata.py
new file mode 100644
index 00000000..3444ab41
--- /dev/null
+++ b/gguf-py/scripts/gguf-new-metadata.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+import logging
+import argparse
+import os
+import sys
+import json
+from pathlib import Path
+
+import numpy as np
+from typing import Any, Mapping, Sequence
+
+# Necessary to load the local gguf package
+if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists():
+ sys.path.insert(0, str(Path(__file__).parent.parent))
+
+import gguf
+
+logger = logging.getLogger("gguf-new-metadata")
+
+
+def get_byteorder(reader: gguf.GGUFReader) -> gguf.GGUFEndian:
+ if np.uint32(1) == np.uint32(1).newbyteorder("<"):
+ # Host is little endian
+ host_endian = gguf.GGUFEndian.LITTLE
+ swapped_endian = gguf.GGUFEndian.BIG
+ else:
+ # Sorry PDP or other weird systems that don't use BE or LE.
+ host_endian = gguf.GGUFEndian.BIG
+ swapped_endian = gguf.GGUFEndian.LITTLE
+
+ if reader.byte_order == "S":
+ return swapped_endian
+ else:
+ return host_endian
+
+
+def decode_field(field: gguf.ReaderField) -> Any:
+ if field and field.types:
+ main_type = field.types[0]
+
+ if main_type == gguf.GGUFValueType.ARRAY:
+ sub_type = field.types[-1]
+
+ if sub_type == gguf.GGUFValueType.STRING:
+ return [str(bytes(field.parts[idx]), encoding='utf8') for idx in field.data]
+ else:
+ return [pv for idx in field.data for pv in field.parts[idx].tolist()]
+ if main_type == gguf.GGUFValueType.STRING:
+ return str(bytes(field.parts[-1]), encoding='utf8')
+ else:
+ return field.parts[-1][0]
+
+ return None
+
+
+def get_field_data(reader: gguf.GGUFReader, key: str) -> Any:
+ field = reader.get_field(key)
+
+ return decode_field(field)
+
+
+def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: Mapping[str, str], remove_metadata: Sequence[str]) -> None:
+ for field in reader.fields.values():
+ # Suppress virtual fields and fields written by GGUFWriter
+ if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'):
+ logger.debug(f'Suppressing {field.name}')
+ continue
+
+ # Skip old chat templates if we have new ones
+ if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
+ logger.debug(f'Skipping {field.name}')
+ continue
+
+ if field.name in remove_metadata:
+ logger.debug(f'Removing {field.name}')
+ continue
+
+ old_val = decode_field(field)
+ val = new_metadata.get(field.name, old_val)
+
+ if field.name in new_metadata:
+ logger.debug(f'Modifying {field.name}: "{old_val}" -> "{val}"')
+ del new_metadata[field.name]
+ elif val is not None:
+ logger.debug(f'Copying {field.name}')
+
+ if val is not None:
+ writer.add_key(field.name)
+ writer.add_val(val, field.types[0])
+
+ if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata:
+ logger.debug('Adding chat template(s)')
+ writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE])
+ del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE]
+
+ # TODO: Support other types than string?
+ for key, val in new_metadata.items():
+ logger.debug(f'Adding {key}: {val}')
+ writer.add_key(key)
+ writer.add_val(val, gguf.GGUFValueType.STRING)
+
+ for tensor in reader.tensors:
+ # Dimensions are written in reverse order, so flip them first
+ shape = np.flipud(tensor.shape)
+ writer.add_tensor_info(tensor.name, shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type)
+
+ writer.write_header_to_file()
+ writer.write_kv_data_to_file()
+ writer.write_ti_data_to_file()
+
+ for tensor in reader.tensors:
+ writer.write_tensor_data(tensor.data)
+
+ writer.close()
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata")
+ parser.add_argument("input", type=Path, help="GGUF format model input filename")
+ parser.add_argument("output", type=Path, help="GGUF format model output filename")
+ parser.add_argument("--general-name", type=str, help="The models general.name")
+ parser.add_argument("--general-description", type=str, help="The models general.description")
+ parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)")
+ parser.add_argument("--chat-template-config", type=Path, help="Config file (tokenizer_config.json) containing chat template(s)")
+ parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model")
+ parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation")
+ parser.add_argument("--verbose", action="store_true", help="Increase output verbosity")
+ args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"])
+
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
+
+ new_metadata = {}
+ remove_metadata = args.remove_metadata or []
+
+ if args.general_name:
+ new_metadata[gguf.Keys.General.NAME] = args.general_name
+
+ if args.general_description:
+ new_metadata[gguf.Keys.General.DESCRIPTION] = args.general_description
+
+ if args.chat_template:
+ new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template
+
+ if args.chat_template_config:
+ with open(args.chat_template_config, 'r') as fp:
+ config = json.load(fp)
+ template = config.get('chat_template')
+ if template:
+ new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = template
+
+ if remove_metadata:
+ logger.warning('*** Warning *** Warning *** Warning **')
+ logger.warning('* Most metadata is required for a fully functional GGUF file,')
+ logger.warning('* removing crucial metadata may result in a corrupt output file!')
+
+ if not args.force:
+ logger.warning('* Enter exactly YES if you are positive you want to proceed:')
+ response = input('YES, I am sure> ')
+ if response != 'YES':
+ logger.info("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+
+ logger.info(f'* Loading: {args.input}')
+ reader = gguf.GGUFReader(args.input, 'r')
+
+ arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE)
+ endianess = get_byteorder(reader)
+
+ if os.path.isfile(args.output) and not args.force:
+ logger.warning('*** Warning *** Warning *** Warning **')
+ logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!')
+ logger.warning('* Enter exactly YES if you are positive you want to proceed:')
+ response = input('YES, I am sure> ')
+ if response != 'YES':
+ logger.info("You didn't enter YES. Okay then, see ya!")
+ sys.exit(0)
+
+ logger.info(f'* Writing: {args.output}')
+ writer = gguf.GGUFWriter(args.output, arch=arch, endianess=endianess)
+
+ alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT)
+ if alignment is not None:
+ logger.debug(f'Setting custom alignment: {alignment}')
+ writer.data_alignment = alignment
+
+ copy_with_new_metadata(reader, writer, new_metadata, remove_metadata)
+
+
+if __name__ == '__main__':
+ main()