@@ -3,7 +3,6 @@ from abc import ABC, ABCMeta, abstractmethod
 import logging
 from typing import Any, Callable
-from collections import deque
 
 import numpy as np
 from numpy.typing import DTypeLike
 
@@ -74,20 +73,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     _tensor_type: type
     _meta: Any
     _data: Any | None
-    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
     _args: tuple
-    _func: Callable[[tuple], Any] | None
+    _kwargs: dict[str, Any]
+    _func: Callable[[Any], Any] | None
 
-    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
+    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
         super().__init__()
         self._meta = meta
         self._data = data
-        self._lazy = lazy if lazy is not None else deque()
         self._args = args
+        self._kwargs = kwargs if kwargs is not None else {}
         self._func = func
         assert self._func is not None or self._data is not None
-        if self._data is None:
-            self._lazy.append(self)
 
     def __init_subclass__(cls) -> None:
         if "_tensor_type" not in cls.__dict__:
@@ -117,6 +114,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
             args = ((use_self,) if use_self is not None else ()) + args
 
             meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
+            # TODO: maybe handle tensors in kwargs too
 
             if isinstance(meta_noop, bool) and not meta_noop:
                 try:
@@ -140,23 +138,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
                         res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)
 
             if isinstance(res, cls._tensor_type):
-                class CollectSharedLazy:
-                    # emulating a static variable
-                    shared_lazy: None | deque[LazyBase] = None
-
-                    @staticmethod
-                    def collect_replace(t: LazyBase):
-                        if CollectSharedLazy.shared_lazy is None:
-                            CollectSharedLazy.shared_lazy = t._lazy
-                        else:
-                            CollectSharedLazy.shared_lazy.extend(t._lazy)
-                        t._lazy = CollectSharedLazy.shared_lazy
-
-                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)
-
-                shared_lazy = CollectSharedLazy.shared_lazy
-
-                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
+                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
             else:
                 del res  # not needed
                 # non-tensor return likely relies on the contents of the args
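
Note how the stored callback is simplified here: before, `kwargs` were captured in a closure and `_func` received the whole argument tuple; after, the original `fn` is stored as-is and the arguments are unpacked at evaluation time. A minimal sketch of the two calling conventions, with hypothetical names (`fn` here is just a stand-in for a wrapped op):

    from typing import Any, Callable

    def fn(x: int, y: int, *, scale: int = 1) -> int:  # hypothetical stand-in
        return (x + y) * scale

    args, kwargs = (2, 3), {"scale": 10}

    # before: kwargs captured in a closure, args passed as a single tuple
    old_func: Callable[[tuple], Any] = lambda a: fn(*a, **kwargs)
    assert old_func(args) == 50

    # after: fn stored directly; args/kwargs unpacked at call time (see to_eager)
    assert fn(*args, **kwargs) == 50
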
@@ -168,26 +150,18 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def to_eager(cls, t: Any) -> Any:
         def simple_to_eager(_t: LazyBase) -> Any:
-            def already_eager_to_eager(_t: LazyBase) -> Any:
-                assert _t._data is not None
+            if _t._data is not None:
                 return _t._data
 
-            while _t._data is None:
-                lt = _t._lazy.popleft()
-                if lt._data is not None:
-                    # Lazy tensor did not belong in the lazy queue.
-                    # Weirdly only happens with Bloom models...
-                    # likely because tensors aren't unique in the queue.
-                    # The final output is still the same as in eager mode,
-                    # so it's safe to ignore this.
-                    continue
-                assert lt._func is not None
-                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
-                lt._data = lt._func(lt._args)
-                # sanity check
-                assert lt._data is not None
-                assert lt._data.dtype == lt._meta.dtype
-                assert lt._data.shape == lt._meta.shape
+            # NOTE: there's a recursion limit in Python (usually 1000)
+
+            assert _t._func is not None
+            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
+            _t._data = _t._func(*_t._args, **_t._kwargs)
+            # sanity check
+            assert _t._data is not None
+            assert _t._data.dtype == _t._meta.dtype
+            assert _t._data.shape == _t._meta.shape
 
             return _t._data
 
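
The NOTE above names the tradeoff of this hunk: the shared `_lazy` deque existed to evaluate the graph iteratively, while the new `simple_to_eager` recurses through `_recurse_apply`, so a chain of lazy ops deeper than CPython's recursion limit would raise `RecursionError`. A hedged illustration, not part of the patch:

    import sys

    # CPython's default recursion limit; a lazy graph deeper than this
    # would make the recursive simple_to_eager raise RecursionError.
    print(sys.getrecursionlimit())  # usually 1000

    # If such deep graphs ever appear in practice, one possible workaround
    # (at the cost of real stack usage) would be:
    # sys.setrecursionlimit(10_000)
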
@@ -206,7 +180,7 @@ class LazyBase(ABC, metaclass=LazyMeta):
     @classmethod
     def from_eager(cls, t: Any) -> Any:
         if type(t) is cls:
-            # already eager
+            # already lazy
             return t
         elif isinstance(t, cls._tensor_type):
             return cls(meta=cls.eager_to_meta(t), data=t)
@@ -228,8 +202,7 @@ class LazyNumpyTensor(LazyBase):
     def astype(self, dtype, *args, **kwargs):
         meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
         full_args = (self, dtype,) + args
-        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
-        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))
+        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))
 
     def tofile(self, *args, **kwargs):
         eager = LazyNumpyTensor.to_eager(self)
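
Taken together, a lazy node now records `(args, kwargs, func)` and `to_eager` replays them depth-first. A rough standalone sketch of the resulting behavior, using a toy class rather than the real `LazyNumpyTensor` (names outside the diff are hypothetical):

    import numpy as np

    class ToyLazy:
        # mirrors the patched constructor: no shared deque, just (args, kwargs, func)
        def __init__(self, *, meta, data=None, args=(), kwargs=None, func=None):
            self._meta = meta
            self._data = data
            self._args = args
            self._kwargs = kwargs if kwargs is not None else {}
            self._func = func
            assert self._func is not None or self._data is not None

    def to_eager(t):
        if isinstance(t, ToyLazy):
            if t._data is None:
                # recursive, like the new simple_to_eager (recursion limit applies)
                args = tuple(to_eager(a) for a in t._args)
                t._data = t._func(*args, **t._kwargs)
                assert t._data.dtype == t._meta.dtype
            return t._data
        return t

    x = ToyLazy(meta=np.empty(3, dtype=np.float32), data=np.ones(3, dtype=np.float32))
    # defer an astype, the same way the patched LazyNumpyTensor.astype does
    y = ToyLazy(
        meta=np.empty(3, dtype=np.float16),
        args=(x, np.float16),
        func=lambda a, *rest, **kw: a.astype(*rest, **kw),
    )
    print(to_eager(y))  # -> array([1., 1., 1.], dtype=float16)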