diff options
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r-- | gguf-py/gguf/gguf_writer.py | 77 |
1 files changed, 68 insertions, 9 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index d9cfbf71..8dcf9330 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -7,7 +7,7 @@ import struct import tempfile from enum import Enum, auto from io import BufferedWriter -from typing import IO, Any, Sequence, Mapping +from typing import IO, Any, Callable, Sequence, Mapping from string import ascii_letters, digits import numpy as np @@ -28,6 +28,47 @@ from .constants import ( logger = logging.getLogger(__name__) +class LazyTensor: + data: Callable[[], np.ndarray[Any, Any]] + # to avoid too deep recursion + functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]] + dtype: np.dtype[Any] + shape: tuple[int, ...] + + def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]): + self.data = data + self.functions = [] + self.dtype = np.dtype(dtype) + self.shape = shape + + def astype(self, dtype: type, **kwargs) -> LazyTensor: + self.functions.append(lambda n: n.astype(dtype, **kwargs)) + self.dtype = np.dtype(dtype) + return self + + @property + def nbytes(self) -> int: + size = 1 + for n in self.shape: + size *= n + return size * self.dtype.itemsize + + def tofile(self, *args, **kwargs) -> None: + data = self.data() + for f in self.functions: + data = f(data) + assert data.shape == self.shape + assert data.dtype == self.dtype + assert data.nbytes == self.nbytes + self.functions = [] + self.data = lambda: data + data.tofile(*args, **kwargs) + + def byteswap(self, *args, **kwargs) -> LazyTensor: + self.functions.append(lambda n: n.byteswap(*args, **kwargs)) + return self + + class WriterState(Enum): EMPTY = auto() HEADER = auto() @@ -38,7 +79,7 @@ class WriterState(Enum): class GGUFWriter: fout: BufferedWriter temp_file: tempfile.SpooledTemporaryFile[bytes] | None - tensors: list[np.ndarray[Any, Any]] + tensors: list[np.ndarray[Any, Any] | LazyTensor] _simple_value_packing = { GGUFValueType.UINT8: "B", GGUFValueType.INT8: "b", @@ -176,7 +217,7 @@ class GGUFWriter: if pack_fmt is not None: self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) elif vtype == GGUFValueType.STRING: - encoded_val = val.encode("utf8") if isinstance(val, str) else val + encoded_val = val.encode("utf-8") if isinstance(val, str) else val self.kv_data += self._pack("Q", len(encoded_val)) self.kv_data += encoded_val elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and val: @@ -205,7 +246,7 @@ class GGUFWriter: raise ValueError(f'Duplicated tensor name {name}') self.ti_names.add(name) - encoded_name = name.encode("utf8") + encoded_name = name.encode("utf-8") self.ti_data += self._pack("Q", len(encoded_name)) self.ti_data += encoded_name n_dims = len(tensor_shape) @@ -237,7 +278,7 @@ class GGUFWriter: self.ti_data_count += 1 def add_tensor( - self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, + self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.endianess == GGUFEndian.BIG: @@ -262,7 +303,7 @@ class GGUFWriter: if pad != 0: fp.write(bytes([0] * pad)) - def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: + def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None: if self.state is not WriterState.TI_DATA: raise ValueError(f'Expected output file to contain tensor info, got {self.state}') @@ -272,15 +313,33 @@ class GGUFWriter: tensor.tofile(self.fout) self.write_padding(self.fout, tensor.nbytes) - def write_tensors_to_file(self) -> None: + def write_tensors_to_file(self, *, progress: bool = False) -> None: self.write_ti_data_to_file() self.write_padding(self.fout, self.fout.tell()) if self.temp_file is None: + self.tensors.reverse() # to pop from the "beginning" in constant time + + if progress: + from tqdm import tqdm + + total_bytes = sum(t.nbytes for t in self.tensors) + + bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + + while True: + try: + tensor = self.tensors.pop() + except IndexError: + break + tensor.tofile(self.fout) + bar.update(tensor.nbytes) + self.write_padding(self.fout, tensor.nbytes) + return while True: try: - tensor = self.tensors.pop(0) + tensor = self.tensors.pop() except IndexError: break tensor.tofile(self.fout) @@ -479,7 +538,7 @@ class GGUFWriter: self.add_bool(Keys.Tokenizer.ADD_PREFIX, value) def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: - if isinstance(value, list): + if not isinstance(value, str): template_default = None template_names = set() |