diff options
Diffstat (limited to 'gguf-py/gguf/lazy.py')
-rw-r--r-- | gguf-py/gguf/lazy.py | 71 |
1 files changed, 23 insertions, 48 deletions
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py index 1167335b..ac98d9a9 100644 --- a/gguf-py/gguf/lazy.py +++ b/gguf-py/gguf/lazy.py @@ -3,10 +3,8 @@ from abc import ABC, ABCMeta, abstractmethod import logging from typing import Any, Callable -from collections import deque import numpy as np -from numpy._typing import _Shape from numpy.typing import DTypeLike @@ -16,16 +14,16 @@ logger = logging.getLogger(__name__) class LazyMeta(ABCMeta): def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): - def __getattr__(self, __name: str) -> Any: - meta_attr = getattr(self._meta, __name) + def __getattr__(self, name: str) -> Any: + meta_attr = getattr(self._meta, name) if callable(meta_attr): return type(self)._wrap_fn( - (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)), + (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)), use_self=self, ) elif isinstance(meta_attr, self._tensor_type): # e.g. self.T with torch.Tensor should still be wrapped - return type(self)._wrap_fn(lambda s: getattr(s, __name))(self) + return type(self)._wrap_fn(lambda s: getattr(s, name))(self) else: # no need to wrap non-tensor properties, # and they likely don't depend on the actual contents of the tensor @@ -75,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta): _tensor_type: type _meta: Any _data: Any | None - _lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager _args: tuple - _func: Callable[[tuple], Any] | None + _kwargs: dict[str, Any] + _func: Callable[[Any], Any] | None - def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None): + def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None): super().__init__() self._meta = meta self._data = data - self._lazy = lazy if lazy is not None else deque() self._args = args + self._kwargs = kwargs if kwargs is not None else {} self._func = func assert self._func is not None or self._data is not None - if self._data is None: - self._lazy.append(self) def __init_subclass__(cls) -> None: if "_tensor_type" not in cls.__dict__: @@ -118,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta): args = ((use_self,) if use_self is not None else ()) + args meta_args = LazyBase._recurse_apply(args, lambda t: t._meta) + # TODO: maybe handle tensors in kwargs too if isinstance(meta_noop, bool) and not meta_noop: try: @@ -141,21 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta): res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) if isinstance(res, cls._tensor_type): - def collect_replace(t: LazyBase): - if collect_replace.shared_lazy is None: - collect_replace.shared_lazy = t._lazy - else: - collect_replace.shared_lazy.extend(t._lazy) - t._lazy = collect_replace.shared_lazy - - # emulating a static variable - collect_replace.shared_lazy = None - - LazyBase._recurse_apply(args, collect_replace) - - shared_lazy = collect_replace.shared_lazy - - return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs)) + return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn) else: del res # not needed # non-tensor return likely relies on the contents of the args @@ -167,25 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta): @classmethod def to_eager(cls, t: Any) -> Any: def simple_to_eager(_t: LazyBase) -> Any: - def already_eager_to_eager(_t: LazyBase) -> Any: - assert _t._data is not None + if _t._data is not None: return _t._data - while _t._data is None: - lt = _t._lazy.popleft() - if lt._data is not None: - # Lazy tensor did not belong in the lazy queue. - # Weirdly only happens with Bloom models... - # likely because tensors aren't unique in the queue. - # The final output is still the same as in eager mode, - # so it's safe to ignore this. - continue - assert lt._func is not None - lt._args = cls._recurse_apply(lt._args, already_eager_to_eager) - lt._data = lt._func(lt._args) - # sanity check - assert lt._data.dtype == lt._meta.dtype - assert lt._data.shape == lt._meta.shape + # NOTE: there's a recursion limit in Python (usually 1000) + + assert _t._func is not None + _t._args = cls._recurse_apply(_t._args, simple_to_eager) + _t._data = _t._func(*_t._args, **_t._kwargs) + # sanity check + assert _t._data is not None + assert _t._data.dtype == _t._meta.dtype + assert _t._data.shape == _t._meta.shape return _t._data @@ -204,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta): @classmethod def from_eager(cls, t: Any) -> Any: if type(t) is cls: - # already eager + # already lazy return t elif isinstance(t, cls._tensor_type): return cls(meta=cls.eager_to_meta(t), data=t) @@ -216,7 +192,7 @@ class LazyNumpyTensor(LazyBase): _tensor_type = np.ndarray @classmethod - def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]: + def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: # The initial idea was to use np.nan as the fill value, # but non-float types like np.int16 can't use that. # So zero it is. @@ -226,8 +202,7 @@ class LazyNumpyTensor(LazyBase): def astype(self, dtype, *args, **kwargs): meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape) full_args = (self, dtype,) + args - # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere. - return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs))) + return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs))) def tofile(self, *args, **kwargs): eager = LazyNumpyTensor.to_eager(self) |