diff options
Diffstat (limited to 'gguf-py/gguf/lazy.py')
-rw-r--r-- | gguf-py/gguf/lazy.py | 29 |
1 files changed, 20 insertions, 9 deletions
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index 650bea11..1167335b 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -6,6 +6,7 @@ from typing import Any, Callable from collections import deque import numpy as np +from numpy._typing import _Shape from numpy.typing import DTypeLike @@ -110,7 +111,7 @@ class LazyBase(ABC, metaclass=LazyMeta): return o @classmethod - def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike = False) -> Callable[[Any], Any]: + def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]: def wrapped_fn(*args, **kwargs): if kwargs is None: kwargs = {} @@ -130,9 +131,14 @@ class LazyBase(ABC, metaclass=LazyMeta): res = args[0] assert isinstance(res, cls) res = res._meta - # allow operations to override the dtype + # allow operations to override the dtype and shape if meta_noop is not True: - res = cls.meta_with_dtype(res, meta_noop) + if isinstance(meta_noop, tuple): + dtype, shape = meta_noop + assert callable(shape) + res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape)) + else: + res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) if isinstance(res, cls._tensor_type): def collect_replace(t: LazyBase): @@ -168,7 +174,12 @@ class LazyBase(ABC, metaclass=LazyMeta): while _t._data is None: lt = _t._lazy.popleft() if lt._data is not None: - raise ValueError(f"{lt} did not belong in the lazy queue") + # Lazy tensor did not belong in the lazy queue. + # Weirdly only happens with Bloom models... + # likely because tensors aren't unique in the queue. + # The final output is still the same as in eager mode, + # so it's safe to ignore this. + continue assert lt._func is not None lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) lt._data = lt._func(lt._args) @@ -183,12 +194,12 @@ class LazyBase(ABC, metaclass=LazyMeta): @classmethod def eager_to_meta(cls, t: Any) -> Any: - return cls.meta_with_dtype(t, t.dtype) + return cls.meta_with_dtype_and_shape(t.dtype, t.shape) # must be overridden, meta tensor init is backend-specific @classmethod @abstractmethod - def meta_with_dtype(cls, m: Any, dtype: Any) -> Any: pass + def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass @classmethod def from_eager(cls, t: Any) -> Any: @@ -205,15 +216,15 @@ class LazyNumpyTensor(LazyBase): _tensor_type = np.ndarray @classmethod - def meta_with_dtype(cls, m: np.ndarray[Any, Any], dtype: DTypeLike) -> np.ndarray[Any, Any]: + def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]: # The initial idea was to use np.nan as the fill value, # but non-float types like np.int16 can't use that. # So zero it is. cheat = np.zeros(1, dtype) - return np.lib.stride_tricks.as_strided(cheat, m.shape, (0 for _ in m.shape)) + return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape)) def astype(self, dtype, *args, **kwargs): - meta = type(self).meta_with_dtype(self._meta, dtype) + meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape) full_args = (self, dtype,) + args # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere. return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs))) |