lazy.py
from __future__ import annotations
from abc import ABC, ABCMeta, abstractmethod

import logging
from typing import Any, Callable

import numpy as np
from numpy.typing import DTypeLike


logger = logging.getLogger(__name__)


class LazyMeta(ABCMeta):
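    """Metaclass for lazy tensor wrappers.

    Special methods bypass __getattr__, so this metaclass injects wrapped
    versions of the operator dunders into the class namespace, keeping
    expressions like `a + b` or `a[0]` lazy.
    """
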
    def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs):
        def __getattr__(self, name: str) -> Any:
            meta_attr = getattr(self._meta, name)
            if callable(meta_attr):
                return type(self)._wrap_fn(
                    (lambda s, *args, **kwargs: getattr(s, name)(*args, **kwargs)),
                    use_self=self,
                )
            elif isinstance(meta_attr, self._tensor_type):
                # e.g. self.T with torch.Tensor should still be wrapped
                return type(self)._wrap_fn(lambda s: getattr(s, name))(self)
            else:
                # no need to wrap non-tensor properties,
                # and they likely don't depend on the actual contents of the tensor
                return meta_attr

        namespace["__getattr__"] = __getattr__

        # A builder function is needed so that each closure captures its own op name;
        # otherwise the same string would end up in every closure,
        # which fails with very cryptic error messages.
        def mk_wrap(op_name: str, *, meta_noop: bool = False):
            # need to wrap the wrapper to get self
            def wrapped_special_op(self, *args, **kwargs):
                return type(self)._wrap_fn(
                    getattr(type(self)._tensor_type, op_name),
                    meta_noop=meta_noop,
                )(self, *args, **kwargs)
            return wrapped_special_op

        # special methods bypass __getattr__, so they need to be added manually
        # ref: https://docs.python.org/3/reference/datamodel.html#special-lookup
        # NOTE: doing this from a metaclass is very convenient
        # TODO: make this even more comprehensive
        for binary_op in (
            "lt", "le", "eq", "ne", "ge", "gt",
            "add", "and", "floordiv", "lshift", "mod", "mul", "matmul",
            "or", "pow", "rshift", "sub", "truediv", "xor",
            "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor",
            "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor",
        ):
            attr_name = f"__{binary_op}__"
            # evaluation on the meta tensor is needed in case there's broadcasting
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        for unary_op in ("not", "abs", "invert", "neg", "pos"):
            attr_name = f"__{unary_op}__"
            # the result of these operators usually has the same shape and dtype as the input,
            # so evaluation on the meta tensor can be skipped.
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=True)

        for special_op in (
            "getitem", "setitem", "len",
        ):
            attr_name = f"__{special_op}__"
            namespace[attr_name] = mk_wrap(attr_name, meta_noop=False)

        return super().__new__(cls, name, bases, namespace, **kwargs)


# Tree of lazy tensors
class LazyBase(ABC, metaclass=LazyMeta):
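    """A node in a tree of lazy tensors.

    Each instance holds a backend-specific meta tensor (shape and dtype only),
    plus either already-computed data or the function and arguments needed to
    compute that data on demand (see to_eager).
    """
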
    _tensor_type: type
    _meta: Any
    _data: Any | None
    _args: tuple
    _kwargs: dict[str, Any]
    _func: Callable[[Any], Any] | None

    def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None):
        super().__init__()
        self._meta = meta
        self._data = data
        self._args = args
        self._kwargs = kwargs if kwargs is not None else {}
        self._func = func
        assert self._func is not None or self._data is not None

    def __init_subclass__(cls) -> None:
        if "_tensor_type" not in cls.__dict__:
            raise TypeError(f"property '_tensor_type' must be defined for {cls!r}")
        return super().__init_subclass__()

    @staticmethod
    def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any:
        # TODO: dict and set
        if isinstance(o, (list, tuple)):
            L = []
            for item in o:
                L.append(LazyBase._recurse_apply(item, fn))
            if isinstance(o, tuple):
                L = tuple(L)
            return L
        elif isinstance(o, LazyBase):
            return fn(o)
        else:
            return o

    @classmethod
    def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]:
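        # meta_noop controls how the result's meta tensor is derived:
        #   False        -> run fn on the meta tensors (needed e.g. when broadcasting changes the shape)
        #   True         -> reuse the first argument's meta tensor unchanged
        #   a dtype      -> keep the input shape but override the dtype
        #   (dtype, fn)  -> override the dtype and derive the shape from the input shape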
        def wrapped_fn(*args, **kwargs):
            if kwargs is None:
                kwargs = {}
            args = ((use_self,) if use_self is not None else ()) + args

            meta_args = LazyBase._recurse_apply(args, lambda t: t._meta)
            # TODO: maybe handle tensors in kwargs too

            if isinstance(meta_noop, bool) and not meta_noop:
                try:
                    res = fn(*meta_args, **kwargs)
                except NotImplementedError:
                    # running some operations on PyTorch's Meta tensors can cause this exception
                    res = None
            else:
                # some operators don't need to actually run on the meta tensors
                assert len(args) > 0
                res = args[0]
                assert isinstance(res, cls)
                res = res._meta
                # allow operations to override the dtype and shape
                if meta_noop is not True:
                    if isinstance(meta_noop, tuple):
                        dtype, shape = meta_noop
                        assert callable(shape)
                        res = cls.meta_with_dtype_and_shape(dtype, shape(res.shape))
                    else:
                        res = cls.meta_with_dtype_and_shape(meta_noop, res.shape)

            if isinstance(res, cls._tensor_type):
                return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn)
            elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res):
                # share the evaluation between lazy tuple elements
                shared_args: list = [args, None]

                def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase:
                    assert len(a) == 2
                    if a[1] is None:
                        a[1] = fn(*a[0], **kw)
                    return a[1][i]
                return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res)))
            else:
                del res  # not needed
                # non-tensor return likely relies on the contents of the args
                # (e.g. the result of torch.equal)
                eager_args = cls.to_eager(args)
                return fn(*eager_args, **kwargs)
        return wrapped_fn

    @classmethod
    def to_eager(cls, t: Any) -> Any:
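        # Materialize lazy tensors (possibly nested in lists/tuples) by walking
        # the deferred-computation tree and caching each node's result in _data.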
        def simple_to_eager(_t: LazyBase) -> Any:
            if _t._data is not None:
                return _t._data

            # NOTE: there's a recursion limit in Python (usually 1000)

            assert _t._func is not None
            _t._args = cls._recurse_apply(_t._args, simple_to_eager)
            _t._data = _t._func(*_t._args, **_t._kwargs)
            # sanity check
            assert _t._data is not None
            assert _t._data.dtype == _t._meta.dtype
            assert _t._data.shape == _t._meta.shape

            return _t._data

        # recurse into lists and/or tuples, keeping their structure
        return cls._recurse_apply(t, simple_to_eager)

    @classmethod
    def eager_to_meta(cls, t: Any) -> Any:
        return cls.meta_with_dtype_and_shape(t.dtype, t.shape)

    # must be overridden, meta tensor init is backend-specific
    @classmethod
    @abstractmethod
    def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass

    @classmethod
    def from_eager(cls, t: Any) -> Any:
        if type(t) is cls:
            # already lazy
            return t
        elif isinstance(t, cls._tensor_type):
            return cls(meta=cls.eager_to_meta(t), data=t)
        else:
            raise TypeError(f"{type(t)!r} is not compatible with {cls._tensor_type!r}")


class LazyNumpyTensor(LazyBase):
    _tensor_type = np.ndarray

    shape: tuple[int, ...]  # Makes the type checker happy in quants.py

    @classmethod
    def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]:
        # The initial idea was to use np.nan as the fill value,
        # but non-float types like np.int16 can't use that.
        # So zero it is.
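        # A zero stride in every dimension makes the single element appear
        # broadcast to the full shape without allocating shape-sized memory.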
        cheat = np.zeros(1, dtype)
        return np.lib.stride_tricks.as_strided(cheat, shape, (0 for _ in shape))

    def astype(self, dtype, *args, **kwargs):
        meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape)
        full_args = (self, dtype,) + args
        return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)))

    def tofile(self, *args, **kwargs):
        eager = LazyNumpyTensor.to_eager(self)
        return eager.tofile(*args, **kwargs)

    # TODO: __array_function__
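
# A minimal usage sketch (illustrative only, not part of the module):
#
#   a = LazyNumpyTensor.from_eager(np.arange(6, dtype=np.float32).reshape(2, 3))
#   b = (a * 2).astype(np.float16)     # nothing is computed yet, only meta dtype/shape
#   out = LazyNumpyTensor.to_eager(b)  # the deferred operations run here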