fix NestedTensor attributes lost in op
Signed-off-by: Zhiyuan Chen <[email protected]>
ZhiyuanChen committed Mar 30, 2023
1 parent 5153080 commit bbb7fe0
Showing 2 changed files with 71 additions and 72 deletions.
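
The essence of this fix is a state-propagation pattern: every operator now rebuilds its result with the parent's non-storage attributes (gathered by `_state()`) instead of silently falling back to the class defaults. The sketch below is a minimal, self-contained illustration of that pattern only; `TinyNested` and its attributes are simplified stand-ins for the example, not the DanLing `NestedTensor` API.

from typing import Mapping, Sequence

import torch
from torch import Tensor


class TinyNested:
    """Minimal stand-in illustrating the _state() propagation pattern."""

    def __init__(self, storage: Sequence[Tensor], padding_value: float = 0.0, mask_value: bool = False):
        self.storage = list(storage)
        self.padding_value = padding_value
        self.mask_value = mask_value

    def _state(self) -> Mapping:
        # Everything except the storage itself counts as state.
        return {k: v for k, v in self.__dict__.items() if k != "storage"}

    def __add__(self, other: float) -> "TinyNested":
        # Before the fix: building the result as TinyNested([t + other for t in self.storage])
        # would reset padding_value / mask_value to their class defaults.
        # After the fix: thread the parent's state into the new instance.
        return TinyNested([t + other for t in self.storage], **self._state())


nt = TinyNested([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=-100.0)
assert (nt + 1).padding_value == -100.0  # attributes survive the op

In the diff below, the same change is applied uniformly: each `NestedTensor(...)` construction inside an operator gains `**self._state()`, and `padding_value` / `mask_value` become class-level defaults so there is always state to carry.
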
4 changes: 2 additions & 2 deletions .github/workflows/push.yaml
@@ -8,10 +8,10 @@ jobs:
- uses: actions/checkout@v3
- uses: ricardochaves/[email protected]
with:
use-pylint: false
use-pycodestyle: false
use-black: false
python-root-list: "danling"
extra-pylint-options: "--max-line-length 120"
extra-pylint-options: "--max-line-length 120 --disable E0012,E0401 --fail-under 9.2 --output-format=colorized"
extra-pycodestyle-options: "--max-line-length 120"
extra-flake8-options: "--max-line-length 120"
extra-black-options: "--line-length 120"
139 changes: 69 additions & 70 deletions danling/tensors/nested_tensor.py
@@ -134,6 +134,8 @@ class NestedTensor:

storage: Sequence[Tensor] = []
batch_first: bool = True
padding_value: SupportsFloat = 0.0
mask_value: bool = False

def __init__(
self,
@@ -276,25 +278,27 @@ def where(self, condition, other) -> NestedTensor:
"""

if isinstance(condition, NestedTensor) and isinstance(other, NestedTensor):
return NestedTensor(x.where(c, y) for x, c, y in zip(self.storage, condition.storage, other.storage))
return NestedTensor(
[x.where(c, y) for x, c, y in zip(self.storage, condition.storage, other.storage)], **self._state()
)
if isinstance(condition, NestedTensor):
return NestedTensor(x.where(c, other) for x, c in zip(self.storage, condition.storage))
return NestedTensor([x.where(c, other) for x, c in zip(self.storage, condition.storage)], **self._state())
if isinstance(other, NestedTensor):
return NestedTensor(x.where(condition, y) for x, y in zip(self.storage, other.storage))
return NestedTensor([x.where(condition, y) for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor(x.where(condition, other) for x in self.storage)

def __abs__(self):
return NestedTensor(abs(value) for value in self.storage)
return NestedTensor([abs(value) for value in self.storage], **self._state())

def __add__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x + y for x, y in zip(self.storage, other.storage))
return NestedTensor(value + other for value in self.storage)
return NestedTensor([x + y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value + other for value in self.storage], **self._state())

def __radd__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y + x for x, y in zip(self.storage, other.storage))
return NestedTensor(other + value for value in self.storage)
return NestedTensor([y + x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other + value for value in self.storage], **self._state())

def __iadd__(self, other):
if isinstance(other, NestedTensor):
@@ -307,13 +311,13 @@ def __iadd__(self, other):

def __and__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x & y for x, y in zip(self.storage, other.storage))
return NestedTensor(value & other for value in self.storage)
return NestedTensor([x & y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value & other for value in self.storage], **self._state())

def __rand__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y & x for x, y in zip(self.storage, other.storage))
return NestedTensor(other & value for value in self.storage)
return NestedTensor([y & x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other & value for value in self.storage], **self._state())

def __iand__(self, other):
if isinstance(other, NestedTensor):
@@ -326,13 +330,13 @@ def __iand__(self, other):

def __floordiv__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x // y for x, y in zip(self.storage, other.storage))
return NestedTensor(value // other for value in self.storage)
return NestedTensor([x // y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value // other for value in self.storage], **self._state())

def __rfloordiv__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y // x for x, y in zip(self.storage, other.storage))
return NestedTensor(other // value for value in self.storage)
return NestedTensor([y // x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other // value for value in self.storage], **self._state())

def __ifloordiv__(self, other):
if isinstance(other, NestedTensor):
@@ -345,13 +349,13 @@ def __ifloordiv__(self, other):

def __mod__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x % y for x, y in zip(self.storage, other.storage))
return NestedTensor(value % other for value in self.storage)
return NestedTensor([x % y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value % other for value in self.storage], **self._state())

def __rmod__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y % x for x, y in zip(self.storage, other.storage))
return NestedTensor(other % value for value in self.storage)
return NestedTensor([y % x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other % value for value in self.storage], **self._state())

def __imod__(self, other):
if isinstance(other, NestedTensor):
@@ -364,13 +368,13 @@ def __imod__(self, other):

def __mul__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x * y for x, y in zip(self.storage, other.storage))
return NestedTensor(value * other for value in self.storage)
return NestedTensor([x * y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value * other for value in self.storage], **self._state())

def __rmul__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y * x for x, y in zip(self.storage, other.storage))
return NestedTensor(other * value for value in self.storage)
return NestedTensor([y * x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other * value for value in self.storage], **self._state())

def __imul__(self, other):
if isinstance(other, NestedTensor):
@@ -383,13 +387,13 @@ def __imul__(self, other):

def __matmul__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x @ y for x, y in zip(self.storage, other.storage))
return NestedTensor(value @ other for value in self.storage)
return NestedTensor([x @ y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value @ other for value in self.storage], **self._state())

def __rmatmul__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y @ x for x, y in zip(self.storage, other.storage))
return NestedTensor(other @ value for value in self.storage)
return NestedTensor([y @ x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other @ value for value in self.storage], **self._state())

def __imatmul__(self, other):
if isinstance(other, NestedTensor):
@@ -402,13 +406,13 @@ def __imatmul__(self, other):

def __pow__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x**y for x, y in zip(self.storage, other.storage))
return NestedTensor(value**other for value in self.storage)
return NestedTensor([x**y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value**other for value in self.storage], **self._state())

def __rpow__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y**x for x, y in zip(self.storage, other.storage))
return NestedTensor(other**value for value in self.storage)
return NestedTensor([y**x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other**value for value in self.storage], **self._state())

def __ipow__(self, other):
if isinstance(other, NestedTensor):
@@ -421,13 +425,13 @@ def __ipow__(self, other):

def __truediv__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x / y for x, y in zip(self.storage, other.storage))
return NestedTensor(value / other for value in self.storage)
return NestedTensor([x / y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value / other for value in self.storage], **self._state())

def __rtruediv__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y / x for x, y in zip(self.storage, other.storage))
return NestedTensor(other / value for value in self.storage)
return NestedTensor([y / x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other / value for value in self.storage], **self._state())

def __itruediv__(self, other):
if isinstance(other, NestedTensor):
@@ -440,13 +444,13 @@ def __itruediv__(self, other):

def __sub__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(x - y for x, y in zip(self.storage, other.storage))
return NestedTensor(value - other for value in self.storage)
return NestedTensor([x - y for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([value - other for value in self.storage], **self._state())

def __rsub__(self, other):
if isinstance(other, NestedTensor):
return NestedTensor(y - x for x, y in zip(self.storage, other.storage))
return NestedTensor(other - value for value in self.storage)
return NestedTensor([y - x for x, y in zip(self.storage, other.storage)], **self._state())
return NestedTensor([other - value for value in self.storage], **self._state())

def __isub__(self, other):
if isinstance(other, NestedTensor):
@@ -469,9 +473,9 @@ def __getattr__(self, name) -> Any:
ret = [getattr(i, name) for i in self.storage]
elem = ret[0]
if isinstance(elem, Tensor):
return NestedTensor(ret)
return NestedTensor(ret, **self._state())
if callable(elem):
return NestedTensorFuncWrapper(ret)
return NestedTensorFuncWrapper(ret, state=self._state())
if elem.__hash__ is not None and len(set(ret)) == 1:
return elem
return ret
@@ -494,14 +498,17 @@ def __eq__(self, other) -> Union[bool, Tensor, NestedTensor]: # type: ignore[ov
if isinstance(other, Tensor):
return self.tensor == other
if isinstance(other, SupportsFloat):
return NestedTensor(x == other for x in self.storage)
return NestedTensor([x == other for x in self.storage], **self._state())
raise NotImplementedError(f"Cannot compare {self.__class__.__name__} with {other.__class__.__name__}")

def __getstate__(self) -> Mapping:
def _state(self) -> Mapping:
return {k: v for k, v in self.__dict__.items() if k != "storage"}

def __state__(self) -> Mapping:
return self.__dict__

def __setstate__(self, states: Mapping) -> None:
self.__dict__.update(states)
def __setstate__(self, state: Mapping) -> None:
self.__dict__.update(state)

def __repr__(self):
return self.__class__.__name__ + repr(self.tensor)[len(self.tensor.__class__.__name__) :] # noqa: E203
@@ -516,10 +523,11 @@ def _tensor(storage, batch_first, padding_value: float = 0) -> Tensor:
@staticmethod
@lru_cache(maxsize=None)
def _mask(storage, mask_value: bool = False) -> Tensor:
# pylint: disable=E1101
if storage[0].dim() == 0:
return torch.ones(len(storage), dtype=torch.bool) # pylint: disable=E1101
lens = torch.tensor([len(t) for t in storage], device=storage[0].device) # pylint: disable=E1101
arange = torch.arange(max(lens), device=storage[0].device)[None, :] # pylint: disable=E1101
return torch.ones(len(storage), dtype=torch.bool)
lens = torch.tensor([len(t) for t in storage], device=storage[0].device)
arange = torch.arange(max(lens), device=storage[0].device)[None, :]
return arange >= lens[:, None] if mask_value else arange < lens[:, None]

@staticmethod
@@ -529,12 +537,11 @@ def _device(storage) -> torch.device: # pylint: disable=E1101

@staticmethod
@lru_cache(maxsize=None)
def _size(storage) -> torch.Size: # pylint: disable=E1101
def _size(storage) -> torch.Size:
# pylint: disable=E1101
if storage[0].dim() == 0:
return torch.Size([len(storage)]) # pylint: disable=E1101
return torch.Size( # pylint: disable=E1101
[len(storage), max(t.shape[0] for t in storage), *storage[0].shape[1:]]
)
return torch.Size([len(storage)])
return torch.Size([len(storage), max(t.shape[0] for t in storage), *storage[0].shape[1:]])


NestedTensorFunc = TorchFuncRegistry()
@@ -555,7 +562,7 @@ def mean(
def cat(tensors, dim: int = 0):
if dim != 0:
raise NotImplementedError(f"NestedTensor only supports cat when dim=0, but got {dim}")
return NestedTensor([t for tensor in tensors for t in tensor.storage])
return NestedTensor([t for tensor in tensors for t in tensor.storage], tensors[0]._state())


@NestedTensorFunc.implement(torch.stack) # pylint: disable=E1101
@@ -574,26 +581,15 @@ def isin(elements, test_elements, *, assume_unique: bool = False, invert: bool =

class NestedTensorFuncWrapper:
r"""
Wrapper for tensors to be converted to `NestedTensor`.
`PNTensor` is a subclass of `torch.Tensor`.
It implements two additional methods as `NestedTensor`: `tensor` and `mask`.
Although it is possible to construct `NestedTensor` in dataset,
the best practice is to do so in `collate_fn`.
However, it is hard to tell if a batch of `Tensor` should be stacked or converted to `NestedTensor`.
`PNTensor` is introduced overcome this limitation.
Convert tensors that will be converted to `NestedTensor` to a `PNTensor`,
and all you need to do is to convert `PNTensor` to `NestedTensor` in `collate_fn`.
Function Wrapper to handle NestedTensor as input.
"""

# pylint: disable=R0903

storage: Sequence[Callable] = []
state: Mapping = {}

def __init__(self, callables) -> None:
def __init__(self, callables, state: Optional[Mapping] = None) -> None:
if not isinstance(callables, Sequence):
raise ValueError(f"NestedTensorFuncWrapper must be initialised with a Sequence, bug got {type(callables)}")
if len(callables) == 0:
@@ -603,12 +599,15 @@ def __init__(self, callables) -> None:
f"NestedTensorFuncWrapper must be initialised with a Sequence of Callable, bug got {type(callables[0])}"
)
self.storage = callables
if state is None:
state = {}
self.state = state

def __call__(self, *args, **kwargs) -> Union[NestedTensor, Sequence[Tensor]]:
ret = [call(*args, **kwargs) for call in self.storage]
elem = ret[0]
if isinstance(elem, Tensor):
return NestedTensor(ret)
return NestedTensor(ret, **self.state)
if elem.__hash__ is not None and len(set(ret)) == 1:
return elem
return ret
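
As a side note on the `_mask` helper changed earlier in this diff: the padding mask is produced by one broadcast comparison between a row vector of positions and a column vector of sequence lengths. The following standalone sketch (with made-up lengths) shows the same trick outside the class:

import torch

lens = torch.tensor([3, 1, 2])                   # lengths of three ragged rows
arange = torch.arange(int(lens.max()))[None, :]  # shape (1, 3): positions 0 .. max_len - 1
valid = arange < lens[:, None]                   # shape (3, 3): True where real data exists
# valid == tensor([[ True,  True,  True],
#                  [ True, False, False],
#                  [ True,  True, False]])
# Passing mask_value=True to _mask flips the comparison to arange >= lens[:, None],
# marking padded positions instead of valid ones.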
