# lazy.py
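"""Lazily-evaluated tensor wrappers.

LazyMeta injects a __getattr__ and the usual operator special methods into the
classes it creates, so that (almost) any operation on a lazy tensor returns
another lazy tensor instead of being computed immediately. LazyBase tracks a
"meta" tensor (shape and dtype only) together with the deferred function and
its arguments, and LazyNumpyTensor is the numpy-backed implementation.
Nothing is materialized until to_eager() (or tofile()) is called.
"""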
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod

import logging
from typing import Any, Callable
from collections import deque

import numpy as np
from numpy._typing import _Shape
from numpy.typing import DTypeLike


logger = logging.getLogger(__name__)


class LazyMeta(ABCMeta):

    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
        def __getattr__(self, name: str) -> Any:
            meta_attr = getattr(self._meta, name)
            if callable(meta_attr):
                return type(self)._wrap_fn(
                    (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                    use_self=self,
                )
            elif isinstance(meta_attr, self._tensor_type):
                # e.g. self.T with torch.Tensor should still be wrapped
                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
            else:
                # no need to wrap non-tensor properties,
                # and they likely don't depend on the actual contents of the tensor
                return meta_attr

        namespace["__getattr__"] = __getattr__

        # need a builder function so that each wrapper captures its own op_name,
        # or else it fails with very cryptic error messages,
        # because the late-bound loop variable would otherwise end up as the same string in every closure
        def mk_wrap(op_name: str, *, meta_noop: bool = False):
            # need to wrap the wrapper to get self
            def wrapped_special_op(self, *args, **kwargs):
                return type(self)._wrap_fn(
                    getattr(type(self)._tensor_type, op_name),
                    meta_noop=meta_noop,
                )(self, *args, **kwargs)
            return wrapped_special_op

        # special methods bypass __getattr__, so they need to be added manually
        # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
        # NOTE: doing this from a metaclass is very convenient
        # TODO: make this even more comprehensive
        for binary_op in (
            "lt", "le", "eq", "ne", "ge", "gt", "not",
            "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul",
            "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor",
            "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
            "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
        ):
            attr_name = f"__{binary_op}__"
            # the result of these operators usually has the same shape and dtype as the input,
            # so evaluation on the meta tensor can be skipped.
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)

        for special_op in (
            "getitem", "setitem", "len",
        ):
            attr_name = f"__{special_op}__"
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        return super().__new__(cls, name, bases, namespace, **kwargs)


# Tree of lazy tensors
class LazyBase(ABC, metaclass=LazyMeta):
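    """Node in a graph of deferred tensor operations.

    Each instance carries a backend-specific "meta" tensor (shape and dtype only)
    in _meta, plus either concrete data in _data or a pending _func/_args pair
    describing how to compute it. Pending nodes are appended to a deque (_lazy)
    shared across the whole graph, so to_eager() can evaluate them iteratively
    instead of recursively.
    """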
    _tensor_type: type
    _meta: Any
    _data: Any | None
    _lazy: deque[LazyBase]  # shared within a graph, to avoid deep recursion when making eager
    _args: tuple
    _func: Callable[[tuple], Any] | None

    def __init__(self, *, meta: Any, data: Any | None = None, lazy: deque[LazyBase] | None = None, args: tuple = (), func: Callable[[tuple], Any] | None = None):
        super().__init__()
        self._meta = meta
        self._data = data
        self._lazy = lazy if lazy is not None else deque()
        self._args = args
        self._func = func
        assert self._func is not None or self._data is not None
        if self._data is None:
            self._lazy.append(self)

    def __init_subclass__(cls) -> None:
        if "_tensor_type" not in cls.__dict__:
            raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
        return super().__init_subclass__()

    @staticmethod
    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
        # TODO: dict and set
        if isinstance(o, (list, tuple)):
            L = []
            for item in o:
                L.append(LazyBase._recurse_apply(item, fn))
            if isinstance(o, tuple):
                L = tuple(L)
            return L
        elif isinstance(o, LazyBase):
            return fn(o)
        else:
            return o

    @classmethod
    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
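        """Return a wrapper around fn that records the call as a lazy node.

        meta_noop controls how the result's meta tensor is obtained:
          * False        -> run fn on the meta tensors to derive shape and dtype
          * True         -> reuse the first argument's meta (same shape and dtype)
          * a dtype      -> keep the first argument's shape, override the dtype
          * (dtype, fn)  -> override the dtype and transform the shape with fn
        A non-tensor result forces eager evaluation of the arguments instead.
        """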
        def wrapped_fn(*args, **kwargs):
            if kwargs is None:
                kwargs = {}
            args = ((use_self,) if use_self is not None else ()) + args

            meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)

            if isinstance(meta_noop, bool) and not meta_noop:
                try:
                    res = fn(*meta_args, **kwargs)
                except NotImplementedError:
                    # running some operations on PyTorch's Meta tensors can cause this exception
                    res = None
            else:
                # some operators don't need to actually run on the meta tensors
                assert len(args) > 0
                res = args[0]
                assert isinstance(res, cls)
                res = res._meta
                # allow operations to override the dtype and shape
                if meta_noop is not True:
                    if isinstance(meta_noop, tuple):
                        dtype, shape = meta_noop
                        assert callable(shape)
                        res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
                    else:
                        res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

            if isinstance(res, cls._tensor_type):
                class CollectSharedLazy:
                    # emulating a static variable
                    shared_lazy: None | deque[LazyBase] = None

                    @staticmethod
                    def collect_replace(t: LazyBase):
                        if CollectSharedLazy.shared_lazy is None:
                            CollectSharedLazy.shared_lazy = t._lazy
                        else:
                            CollectSharedLazy.shared_lazy.extend(t._lazy)
                            t._lazy = CollectSharedLazy.shared_lazy

                LazyBase._recurse_apply(args, CollectSharedLazy.collect_replace)

                shared_lazy = CollectSharedLazy.shared_lazy

                return cls(meta=cls.eager_to_meta(res), lazy=shared_lazy, args=args, func=lambda a: fn(*a, **kwargs))
            else:
                del res  # not needed
                # non-tensor return likely relies on the contents of the args
                # (e.g. the result of torch.equal)
                eager_args = cls.to_eager(args)
                return fn(*eager_args, **kwargs)
        return wrapped_fn

    @classmethod
    def to_eager(cls, t: Any) -> Any:
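        """Compute and return the concrete value(s) for t, recursing into lists and tuples.

        Pending nodes are popped from the shared _lazy deque and evaluated in
        insertion order, so the whole graph is resolved without deep recursion.
        """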
        def simple_to_eager(_t: LazyBase) -> Any:
            def already_eager_to_eager(_t: LazyBase) -> Any:
                assert _t._data is not None
                return _t._data

            while _t._data is None:
                lt = _t._lazy.popleft()
                if lt._data is not None:
                    # Lazy tensor did not belong in the lazy queue.
                    # Weirdly only happens with Bloom models...
                    # likely because tensors aren't unique in the queue.
                    # The final output is still the same as in eager mode,
                    # so it's safe to ignore this.
                    continue
                assert lt._func is not None
                lt._args = cls._recurse_apply(lt._args, already_eager_to_eager)
                lt._data = lt._func(lt._args)
                # sanity check
                assert lt._data is not None
                assert lt._data.dtype == lt._meta.dtype
                assert lt._data.shape == lt._meta.shape

            return _t._data

        # recurse into lists and/or tuples, keeping their structure
        return cls._recurse_apply(t, simple_to_eager)

    @classmethod
    def eager_to_meta(cls, t: Any) -> Any:
        return cls.meta_with_dtype_and_shape(t.dtype, t.shape)

    # must be overridden, meta tensor init is backend-specific
    @classmethod
    @abstractmethod
    def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass

    @classmethod
    def from_eager(cls, t: Any) -> Any:
        if type(t) is cls:
            # already lazy
            return t
        elif isinstance(t, cls._tensor_type):
            return cls(meta=cls.eager_to_meta(t), data=t)
        else:
            raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")


class LazyNumpyTensor(LazyBase):
    _tensor_type = np.ndarray

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: _Shape) -> np.ndarray[Any, Any]:
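        """Build a dataless stand-in array with the requested dtype and shape.

        A zero-strided view over a single element reports the full shape
        without allocating memory for it.
        """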
        # The initial idea was to use np.nan as the fill value,
        # but non-float types like np.int16 can't use that.
        # So zero it is.
        cheat = np.zeros(1, dtype)
        return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))

    def astype(self, dtype, *args, **kwargs):
        meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
        full_args = (self, dtype,) + args
        # very important to pass the shared _lazy deque, or else there's an infinite loop somewhere.
        return type(self)(meta=meta, args=full_args, lazy=self._lazy, func=(lambda a: a[0].astype(*a[1:], **kwargs)))

    def tofile(self, *args, **kwargs):
        eager = LazyNumpyTensor.to_eager(self)
        return eager.tofile(*args, **kwargs)

    # TODO: __array_function__
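

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): operations on a
# LazyNumpyTensor only record work; nothing is computed until to_eager() runs.
#
#     a = LazyNumpyTensor.from_eager(np.arange(6, dtype=np.float32).reshape(2, 3))
#     b = (a * 2).astype(np.float16)        # no multiplication or cast happens yet
#     print(b._meta.shape, b._meta.dtype)   # (2, 3) float16, known without computing
#     result = LazyNumpyTensor.to_eager(b)  # the deferred graph is evaluated here
# ---------------------------------------------------------------------------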