summaryrefslogtreecommitdiff
path: root/gguf-py/gguf/gguf_writer.py
diff options
context:
space:
mode:
Diffstat (limited to 'gguf-py/gguf/gguf_writer.py')
-rw-r--r--gguf-py/gguf/gguf_writer.py51
1 files changed, 5 insertions, 46 deletions
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 8dcf9330..96574358 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -7,7 +7,7 @@ import struct
import tempfile
from enum import Enum, auto
from io import BufferedWriter
-from typing import IO, Any, Callable, Sequence, Mapping
+from typing import IO, Any, Sequence, Mapping
from string import ascii_letters, digits
import numpy as np
@@ -28,47 +28,6 @@ from .constants import (
logger = logging.getLogger(__name__)
-class LazyTensor:
- data: Callable[[], np.ndarray[Any, Any]]
- # to avoid too deep recursion
- functions: list[Callable[[np.ndarray[Any, Any]], np.ndarray[Any, Any]]]
- dtype: np.dtype[Any]
- shape: tuple[int, ...]
-
- def __init__(self, data: Callable[[], np.ndarray[Any, Any]], *, dtype: type, shape: tuple[int, ...]):
- self.data = data
- self.functions = []
- self.dtype = np.dtype(dtype)
- self.shape = shape
-
- def astype(self, dtype: type, **kwargs) -> LazyTensor:
- self.functions.append(lambda n: n.astype(dtype, **kwargs))
- self.dtype = np.dtype(dtype)
- return self
-
- @property
- def nbytes(self) -> int:
- size = 1
- for n in self.shape:
- size *= n
- return size * self.dtype.itemsize
-
- def tofile(self, *args, **kwargs) -> None:
- data = self.data()
- for f in self.functions:
- data = f(data)
- assert data.shape == self.shape
- assert data.dtype == self.dtype
- assert data.nbytes == self.nbytes
- self.functions = []
- self.data = lambda: data
- data.tofile(*args, **kwargs)
-
- def byteswap(self, *args, **kwargs) -> LazyTensor:
- self.functions.append(lambda n: n.byteswap(*args, **kwargs))
- return self
-
-
class WriterState(Enum):
EMPTY = auto()
HEADER = auto()
@@ -79,7 +38,7 @@ class WriterState(Enum):
class GGUFWriter:
fout: BufferedWriter
temp_file: tempfile.SpooledTemporaryFile[bytes] | None
- tensors: list[np.ndarray[Any, Any] | LazyTensor]
+ tensors: list[np.ndarray[Any, Any]]
_simple_value_packing = {
GGUFValueType.UINT8: "B",
GGUFValueType.INT8: "b",
@@ -278,7 +237,7 @@ class GGUFWriter:
self.ti_data_count += 1
def add_tensor(
- self, name: str, tensor: np.ndarray[Any, Any] | LazyTensor, raw_shape: Sequence[int] | None = None,
+ self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
) -> None:
if self.endianess == GGUFEndian.BIG:
@@ -303,7 +262,7 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
- def write_tensor_data(self, tensor: np.ndarray[Any, Any] | LazyTensor) -> None:
+ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
if self.state is not WriterState.TI_DATA:
raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
@@ -391,7 +350,7 @@ class GGUFWriter:
def add_name(self, name: str) -> None:
self.add_string(Keys.General.NAME, name)
- def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None:
+ def add_quantization_version(self, quantization_version: int) -> None:
self.add_uint32(
Keys.General.QUANTIZATION_VERSION, quantization_version)