path: root/gguf-py/gguf/lazy.py
author    Kawrakow <48489457+ikawrakow@users.noreply.github.com>    2024-07-27 07:55:01 +0200
committer GitHub <noreply@github.com>    2024-07-27 07:55:01 +0200
commit    154e0d75fccf1784fe9ff6fd76a630b66563da3d (patch)
tree      81ce6dbb5b1900c1aa78a879f0593c694cab9d27 /gguf-py/gguf/lazy.py
parent    0684c3e9c70d49323b4fc517128cbe222cab7f96 (diff)
Merge mainline llama.cpp (#3)
* Merging mainline - WIP
* Merging mainline - WIP
  AVX2 and CUDA appear to work. CUDA performance seems slightly (~1-2%) lower as it is so often the case with llama.cpp/ggml after some "improvements" have been made.
* Merging mainline - fix Metal
* Remove check
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Diffstat (limited to 'gguf-py/gguf/lazy.py')
-rw-r--r--   gguf-py/gguf/lazy.py   71
1 files changed, 23 insertions, 48 deletions
diff --git a/gguf-py/gguf/lazy.py b/gguf-py/gguf/lazy.py
index 1167335b..ac98d9a9 100644
--- a/gguf-py/gguf/lazy.py
+++ b/gguf-py/gguf/lazy.py
@@ -3,10 +3,8 @@ from abc import ABC, ABCMeta, abstractmethod
import logging
from typing import Any, Callable
-from collections import deque
import numpy as np
-from numpy._typing import _Shape
from numpy.typing import DTypeLike
@@ -16,16 +14,16 @@ logger = logging.getLogger(__name__)
class LazyMeta(ABCMeta):
def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
- def __getattr__(self, __name: str) -> Any:
- meta_attr = getattr(self._meta, __name)
+ def __getattr__(self, name: str) -> Any:
+ meta_attr = getattr(self._meta, name)
if callable(meta_attr):
return type(self)._wrap_fn(
- (lambda s, *args, **kwargs: getattr(s, __name)(*args, **kwargs)),
+ (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
use_self=self,
)
elif isinstance(meta_attr, self._tensor_type):
# e.g. self.T with torch.Tensor should still be wrapped
- return type(self)._wrap_fn(lambda s: getattr(s, __name))(self)
+ return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
else:
# no need to wrap non-tensor properties,
# and they likely don't depend on the actual contents of the tensor
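A sketch of what this __getattr__ wrapper enables (not part of the diff, assuming gguf-py is on the import path): a method looked up on a lazy tensor is resolved against the meta array and wrapped, so the real numpy call is recorded rather than executed.

    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    x = LazyNumpyTensor.from_eager(np.arange(6, dtype=np.float32))
    y = x.reshape((2, 3))                # lazy: only the meta array is reshaped
    print(y._meta.shape, y._meta.dtype)  # (2, 3) float32, no real data touched
    z = LazyNumpyTensor.to_eager(y)      # the recorded reshape runs here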
@@ -75,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
_tensor_type: type
_meta: Any
_data: Any | None
- _lazy: deque[LazyBase] # shared within a graph, to avoid deep recursion when making eager
_args: tuple
- _func: Callable[[tuple], Any] | None
+ _kwargs: dict[str, Any]
+ _func: Callable[[Any], Any] | None
- def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+ def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
super().__init__()
self._meta = meta
self._data = data
- self._lazy = lazy if lazy is not None else deque()
self._args = args
+ self._kwargs = kwargs if kwargs is not None else {}
self._func = func
assert self._func is not None or self._data is not None
- if self._data is None:
- self._lazy.append(self)
def __init_subclass__(cls) -> None:
if "_tensor_type" not in cls.__dict__:
@@ -118,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
args = ((use_self,) if use_self is not None else ()) + args
meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+ # TODO: maybe handle tensors in kwargs too
if isinstance(meta_noop, bool) and not meta_noop:
try:
@@ -141,21 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
if isinstance(res, cls._tensor_type):
- def collect_replace(t: LazyBase):
- if collect_replace.shared_lazy is None:
- collect_replace.shared_lazy = t._lazy
- else:
- collect_replace.shared_lazy.extend(t._lazy)
- t._lazy = collect_replace.shared_lazy
-
- # emulating a static variable
- collect_replace.shared_lazy = None
-
- LazyBase._recurse_apply(args, collect_replace)
-
- shared_lazy = collect_replace.shared_lazy
-
- return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+ return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
else:
del res # not needed
# non-tensor return likely relies on the contents of the args
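An assumed illustration of the effect of this hunk (not part of the diff; it pokes at the internal _wrap_fn classmethod shown here): the wrapped function is now stored as-is together with its args and kwargs, instead of a lambda over a shared deque, and only runs when the result is made eager.

    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    x = LazyNumpyTensor.from_eager(np.ones((2, 2), dtype=np.float32))

    # Wrap a plain numpy function; calling it records fn/args/kwargs on the result.
    lazy_sqrt = LazyNumpyTensor._wrap_fn(np.sqrt)
    y = lazy_sqrt(x)                        # no sqrt has run yet
    assert y._data is None and y._func is np.sqrt
    np.testing.assert_allclose(LazyNumpyTensor.to_eager(y), np.ones((2, 2)))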
@@ -167,25 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
@classmethod
def to_eager(cls, t: Any) -> Any:
def simple_to_eager(_t: LazyBase) -> Any:
- def already_eager_to_eager(_t: LazyBase) -> Any:
- assert _t._data is not None
+ if _t._data is not None:
return _t._data
- while _t._data is None:
- lt = _t._lazy.popleft()
- if lt._data is not None:
- # Lazy tensor did not belong in the lazy queue.
- # Weirdly only happens with Bloom models...
- # likely because tensors aren't unique in the queue.
- # The final output is still the same as in eager mode,
- # so it's safe to ignore this.
- continue
- assert lt._func is not None
- lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
- lt._data = lt._func(lt._args)
- # sanity check
- assert lt._data.dtype == lt._meta.dtype
- assert lt._data.shape == lt._meta.shape
+ # NOTE: there's a recursion limit in Python (usually 1000)
+
+ assert _t._func is not None
+ _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+ _t._data = _t._func(*_t._args, **_t._kwargs)
+ # sanity check
+ assert _t._data is not None
+ assert _t._data.dtype == _t._meta.dtype
+ assert _t._data.shape == _t._meta.shape
return _t._data
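To make the recursion note concrete, a hedged sketch (not part of the diff): a chain of N deferred ops becomes roughly N nested simple_to_eager calls, so an extremely deep lazy graph could in principle hit sys.getrecursionlimit() (commonly 1000); graphs produced during model conversion are typically far shallower.

    import sys
    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    x = LazyNumpyTensor.from_eager(np.zeros((8,), dtype=np.float32))
    for _ in range(50):             # each step adds one deferred node to the chain
        x = x.reshape((8,))         # kept lazy; only the meta array is touched
    print(sys.getrecursionlimit())  # typically 1000
    out = LazyNumpyTensor.to_eager(x)  # ~50 nested simple_to_eager calls
    assert out.shape == (8,)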
@@ -204,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
@classmethod
def from_eager(cls, t: Any) -> Any:
if type(t) is cls:
- # already eager
+ # already lazy
return t
elif isinstance(t, cls._tensor_type):
return cls(meta=cls.eager_to_meta(t), data=t)
@@ -216,7 +192,7 @@ class LazyNumpyTensor(LazyBase):
_tensor_type = np.ndarray
@classmethod
- def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
+ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
# The initial idea was to use np.nan as the fill value,
# but non-float types like np.int16 can't use that.
# So zero it is.
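The body of meta_with_dtype_and_shape is cut off in this view; as an illustration only, one way to build such a dtype/shape-only "meta" array without allocating the full buffer is to back every element with a single zeroed element via zero strides:

    import numpy as np

    cheat = np.zeros(1, np.int16)  # the single real element (zero fill, see above)
    meta = np.lib.stride_tricks.as_strided(cheat, shape=(4096, 4096), strides=(0, 0))
    assert meta.dtype == np.int16 and meta.shape == (4096, 4096)
    # nbytes reports the logical size, but only 2 bytes are actually backed
    assert meta.nbytes == 4096 * 4096 * 2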
@@ -226,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
def astype(self, dtype, *args, **kwargs):
meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
full_args = (self, dtype,) + args
- # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
- return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+ return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
def tofile(self, *args, **kwargs):
eager = LazyNumpyTensor.to_eager(self)
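Putting the pieces together, a hedged end-to-end sketch (not part of the diff) of how the class is typically used: a dtype cast via astype stays lazy, and the deferred work only runs when the tensor is written out through tofile.

    import numpy as np
    from gguf.lazy import LazyNumpyTensor

    weights = np.random.rand(16, 16).astype(np.float32)

    t = LazyNumpyTensor.from_eager(weights)
    t16 = t.astype(np.float16)        # recorded, not executed
    assert t16._data is None          # still lazy

    with open("/tmp/tensor.bin", "wb") as f:
        t16.tofile(f)                 # to_eager runs the cast, then writes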