From 56c93cbc88a2ae63bcb4a4b4b11d7f7dba09558d Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 13 Sep 2024 17:48:02 +0100
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../linux_examples/scripts/run_all.sh  |   8 +-
 .../linux_libs/scripts_rlhf/install.sh |   8 +-
 torchrl/__init__.py                    |  13 +-
 torchrl/data/tensor_specs.py           | 112 +++++++++---------
 4 files changed, 73 insertions(+), 68 deletions(-)

diff --git a/.github/unittest/linux_examples/scripts/run_all.sh b/.github/unittest/linux_examples/scripts/run_all.sh
index 18e6075baae..a4f2f07dc25 100755
--- a/.github/unittest/linux_examples/scripts/run_all.sh
+++ b/.github/unittest/linux_examples/scripts/run_all.sh
@@ -150,15 +150,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with %s\n" "${CU_VERSION}"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    pip3 install --pre torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
   else
-    pip3 install --pre torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
+    pip3 install --pre torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/$CU_VERSION
   fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
+    pip3 install torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cpu
   else
-    pip3 install torch torchvision numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/$CU_VERSION
+    pip3 install torch torchvision numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/$CU_VERSION
   fi
 else
   printf "Failed to install pytorch"
diff --git a/.github/unittest/linux_libs/scripts_rlhf/install.sh b/.github/unittest/linux_libs/scripts_rlhf/install.sh
index 4c769ba9bd6..c33934b64f8 100755
--- a/.github/unittest/linux_libs/scripts_rlhf/install.sh
+++ b/.github/unittest/linux_libs/scripts_rlhf/install.sh
@@ -31,15 +31,15 @@ git submodule sync && git submodule update --init --recursive
 printf "Installing PyTorch with cu121"
 if [[ "$TORCH_VERSION" == "nightly" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cpu -U
+    pip3 install --pre torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cpu -U
   else
-    pip3 install --pre torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
+    pip3 install --pre torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/nightly/cu121 -U
   fi
 elif [[ "$TORCH_VERSION" == "stable" ]]; then
   if [ "${CU_VERSION:-}" == cpu ] ; then
-    pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cpu
+    pip3 install torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cpu
   else
-    pip3 install torch numpy==1.26.4 numpy-base<2.0 --index-url https://download.pytorch.org/whl/cu121
+    pip3 install torch numpy==1.26.4 numpy-base==1.26.4 --index-url https://download.pytorch.org/whl/cu121
   fi
 else
   printf "Failed to install pytorch"
diff --git a/torchrl/__init__.py b/torchrl/__init__.py
index aa5fbcb3d5c..cbd7b66a65e 100644
--- a/torchrl/__init__.py
+++ b/torchrl/__init__.py
@@ -54,11 +54,14 @@
 _THREAD_POOL_INIT = torch.get_num_threads()
 
 
+# monkey-patch dist transforms until https://github.com/pytorch/pytorch/pull/135001/ finds a home
 @property
-def inv(self):
-    """
+def _inv(self):
+    """Patched version of Transform.inv.
+
     Returns the inverse :class:`Transform` of this transform.
+
     This should satisfy ``t.inv.inv is t``.
     """
     inv = None
@@ -71,11 +74,11 @@ def inv(self):
     return inv
 
 
-torch.distributions.transforms.Transform.inv = inv
+torch.distributions.transforms.Transform.inv = _inv
 
 
 @property
-def inv(self):
+def _inv(self):
     inv = None
     if self._inv is not None:
         inv = self._inv()
@@ -91,4 +94,4 @@ def inv(self):
     return inv
 
 
-ComposeTransform.inv = inv
+ComposeTransform.inv = _inv
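Note on the hunk above: a minimal sketch of the invariant the patched property is meant to preserve, using a stock transform (the cache still lives in the `_inv` weakref, as upstream):

    from torch.distributions.transforms import ExpTransform

    t = ExpTransform()
    # Round-tripping through the property should hand back the very same
    # object: the ``t.inv.inv is t`` identity named in the new docstring.
    assert t.inv.inv is t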
" @@ -1667,7 +1671,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: n = self.space.n @@ -1746,7 +1750,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.space.n, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, use_register=self.use_register, @@ -1997,9 +2001,9 @@ def __init__( ) if shape is not None and not isinstance(shape, torch.Size): if isinstance(shape, int): - shape = torch.Size([shape]) + shape = _size([shape]) else: - shape = torch.Size(list(shape)) + shape = _size(list(shape)) if shape is not None: shape_corr = _remove_neg_shapes(shape) else: @@ -2032,9 +2036,9 @@ def __init__( shape = low.shape else: if isinstance(shape_corr, float): - shape_corr = torch.Size([shape_corr]) + shape_corr = _size([shape_corr]) elif not isinstance(shape_corr, torch.Size): - shape_corr = torch.Size(shape_corr) + shape_corr = _size(shape_corr) shape_corr_err_msg = ( f"low and shape_corr mismatch, got {low.shape} and {shape_corr}" ) @@ -2167,7 +2171,7 @@ def unbind(self, dim: int = 0): def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) a, b = self.space if self.dtype in (torch.float, torch.double, torch.half): shape = [*shape, *self._safe_shape] @@ -2191,9 +2195,7 @@ def rand(self, shape: torch.Size = None) -> torch.Tensor: else: mini = self.space.low interval = maxi - mini - r = torch.rand( - torch.Size([*shape, *self._safe_shape]), device=interval.device - ) + r = torch.rand(_size([*shape, *self._safe_shape]), device=interval.device) r = interval * r r = self.space.low + r r = r.to(self.dtype).to(self.device) @@ -2284,7 +2286,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): "Pending resolution of https://github.com/pytorch/pytorch/issues/100080." 
@@ -640,7 +646,7 @@ def __ne__(self, other):
 
     def __setattr__(self, key, value):
         if key == "shape":
-            value = torch.Size(value)
+            value = _size(value)
         super().__setattr__(key, value)
 
     def to_numpy(
@@ -686,7 +692,7 @@ def ndimension(self) -> int:
     @property
     def _safe_shape(self) -> torch.Size:
         """Returns a shape where all heterogeneous values are replaced by one (to be expandable)."""
-        return torch.Size([int(v) if v >= 0 else 1 for v in self.shape])
+        return _size([int(v) if v >= 0 else 1 for v in self.shape])
 
     @abc.abstractmethod
     def index(
@@ -752,9 +758,7 @@ def make_neg_dim(self, dim: int) -> T:
             dim = self.ndim + dim
         if dim < 0 or dim > self.ndim - 1:
             raise ValueError(f"dim={dim} is out of bound for ndim={self.ndim}")
-        self.shape = torch.Size(
-            [s if i != dim else -1 for i, s in enumerate(self.shape)]
-        )
+        self.shape = _size([s if i != dim else -1 for i, s in enumerate(self.shape)])
 
     @overload
     def reshape(self, shape) -> T:
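Note on the `_safe_shape` and `make_neg_dim` hunks: a standalone sketch of the convention they implement, where -1 marks a heterogeneous (dynamic) dimension and is swapped for 1 so the shape remains usable with `expand`:

    import torch

    shape = torch.Size([-1, 3])  # -1: heterogeneous/dynamic dim
    safe = torch.Size([int(v) if v >= 0 else 1 for v in shape])
    print(safe)                                  # torch.Size([1, 3])
    print(torch.zeros(safe).expand(5, 3).shape)  # torch.Size([5, 3])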
" @@ -2857,7 +2859,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: if shape is None: shape = self.shape[:-1] else: - shape = torch.Size([*shape, *self.shape[:-1]]) + shape = _size([*shape, *self.shape[:-1]]) mask = self.mask if mask is None: @@ -3133,7 +3135,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( nvec=self.nvec, - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) @@ -3198,7 +3200,7 @@ def __init__( mask: torch.Tensor | None = None, ): if shape is None: - shape = torch.Size([]) + shape = _size([]) dtype, device = _default_dtype_and_device(dtype, device) space = CategoricalBox(n) super().__init__( @@ -3241,12 +3243,12 @@ def update_mask(self, mask): def rand(self, shape: torch.Size = None) -> torch.Tensor: if shape is None: - shape = torch.Size([]) + shape = _size([]) if self.mask is None: return torch.randint( 0, self.space.n, - torch.Size([*shape, *self.shape]), + _size([*shape, *self.shape]), device=self.device, dtype=self.dtype, ) @@ -3266,7 +3268,7 @@ def _project(self, val: torch.Tensor) -> torch.Tensor: if self.mask is None: return val.clamp_(min=0, max=self.space.n - 1) shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) oob = ~gathered.all(-1) @@ -3285,14 +3287,14 @@ def is_in(self, val: torch.Tensor) -> bool: return False return (0 <= val).all() and (val < self.space.n).all() shape = self.mask.shape - shape = torch.Size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) + shape = _size([*torch.broadcast_shapes(shape[:-1], val.shape), shape[-1]]) mask_expand = self.mask.expand(shape) gathered = mask_expand.gather(-1, val.unsqueeze(-1)) return gathered.all() def __getitem__(self, idx: SHAPE_INDEX_TYPING): """Indexes the current TensorSpec based on the provided index.""" - indexed_shape = torch.Size(_shape_indexing(self.shape, idx)) + indexed_shape = _size(_shape_indexing(self.shape, idx)) return self.__class__( n=self.space.n, shape=indexed_shape, @@ -3535,9 +3537,9 @@ def __init__( if n is None: n = shape[-1] if shape is None or not len(shape): - shape = torch.Size((n,)) + shape = _size((n,)) else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != n: raise ValueError( f"The last value of the shape must match n for spec {self.__class__}. " @@ -3636,7 +3638,7 @@ def __getitem__(self, idx: SHAPE_INDEX_TYPING): indexed_shape = _shape_indexing(self.shape[:-1], idx) return self.__class__( n=self.shape[-1], - shape=torch.Size(indexed_shape + [self.shape[-1]]), + shape=_size(indexed_shape + [self.shape[-1]]), device=self.device, dtype=self.dtype, ) @@ -3697,7 +3699,7 @@ def __init__( if shape is None: shape = nvec.shape else: - shape = torch.Size(shape) + shape = _size(shape) if shape[-1] != nvec.shape[-1]: raise ValueError( f"The last value of the shape must match nvec.shape[-1] for transform of type {self.__class__}. 
" @@ -3827,7 +3829,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor: *self.shape[:-1], ) x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim) - if self.remove_singleton and self.shape == torch.Size([1]): + if self.remove_singleton and self.shape == _size([1]): x = x.squeeze(-1) return x @@ -4174,7 +4176,7 @@ def shape(self, value: torch.Size): f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and " f"Composite.shape={self.shape}." ) - self._shape = torch.Size(value) + self._shape = _size(value) def is_empty(self): """Whether the composite spec contains specs or not.""" @@ -4211,8 +4213,8 @@ def __init__( shape = batch_size if shape is None: - shape = torch.Size(()) - self._shape = torch.Size(shape) + shape = _size(()) + self._shape = _size(shape) self._specs = {} for key, value in kwargs.items(): self.set(key, value) @@ -4384,7 +4386,7 @@ def encode( if isinstance(vals, TensorDict): out = vals.empty() # create and empty tensordict similar to vals else: - out = TensorDict._new_unsafe({}, torch.Size([])) + out = TensorDict._new_unsafe({}, _size([])) for key, item in vals.items(): if item is None: raise RuntimeError( @@ -4444,7 +4446,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase: def rand(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) _dict = {} for key, item in self.items(): if item is not None: @@ -4453,7 +4455,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase: # TensorDict requirements return TensorDict._new_unsafe( _dict, - batch_size=torch.Size([*shape, *self.shape]), + batch_size=_size([*shape, *self.shape]), device=self._device, ) @@ -4621,7 +4623,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict: def zero(self, shape: torch.Size = None) -> TensorDictBase: if shape is None: - shape = torch.Size([]) + shape = _size([]) try: device = self.device except RuntimeError: @@ -4632,7 +4634,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase: for key in self.keys(True) if isinstance(key, str) and self[key] is not None }, - torch.Size([*shape, *self._safe_shape]), + _size([*shape, *self._safe_shape]), device=device, ) @@ -5078,7 +5080,7 @@ def shape(self): if dim < 0: dim = len(shape) + dim + 1 shape.insert(dim, len(self._specs)) - return torch.Size(shape) + return _size(shape) def expand(self, *shape): if len(shape) == 1 and not isinstance(shape[0], (int,)): @@ -5279,7 +5281,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if dim is None: if len(shape) == 1 or shape.count(1) == 0: return None - new_shape = torch.Size([s for s in shape if s != 1]) + new_shape = _size([s for s in shape if s != 1]) else: if dim < 0: dim += len(shape) @@ -5287,7 +5289,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None: if shape[dim] != 1: return None - new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim]) + new_shape = _size([s for i, s in enumerate(shape) if i != dim]) return new_shape @@ -5303,7 +5305,7 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size: new_shape = list(shape) new_shape.insert(dim, 1) - return torch.Size(new_shape) + return _size(new_shape) class _CompositeSpecItemsView: @@ -5451,7 +5453,7 @@ def _remove_neg_shapes(*shape): if isinstance(shape, np.integer): shape = (int(shape),) return _remove_neg_shapes(*shape) - return torch.Size([int(d) if d >= 0 else 1 for d in shape]) + return _size([int(d) if d >= 0 
@@ -3827,7 +3829,7 @@ def rand(self, shape: Optional[torch.Size] = None) -> torch.Tensor:
             *self.shape[:-1],
         )
         x = self._rand(space=self.space, shape=shape, i=self.nvec.ndim)
-        if self.remove_singleton and self.shape == torch.Size([1]):
+        if self.remove_singleton and self.shape == _size([1]):
             x = x.squeeze(-1)
         return x
 
@@ -4174,7 +4176,7 @@ def shape(self, value: torch.Size):
                     f"{self.ndim} first dimensions should match but got self['{key}'].shape={spec.shape} and "
                     f"Composite.shape={self.shape}."
                 )
-        self._shape = torch.Size(value)
+        self._shape = _size(value)
 
     def is_empty(self):
         """Whether the composite spec contains specs or not."""
@@ -4211,8 +4213,8 @@ def __init__(
             shape = batch_size
 
         if shape is None:
-            shape = torch.Size(())
-        self._shape = torch.Size(shape)
+            shape = _size(())
+        self._shape = _size(shape)
         self._specs = {}
         for key, value in kwargs.items():
             self.set(key, value)
@@ -4384,7 +4386,7 @@ def encode(
         if isinstance(vals, TensorDict):
             out = vals.empty()  # create and empty tensordict similar to vals
         else:
-            out = TensorDict._new_unsafe({}, torch.Size([]))
+            out = TensorDict._new_unsafe({}, _size([]))
         for key, item in vals.items():
             if item is None:
                 raise RuntimeError(
@@ -4444,7 +4446,7 @@ def project(self, val: TensorDictBase) -> TensorDictBase:
 
     def rand(self, shape: torch.Size = None) -> TensorDictBase:
         if shape is None:
-            shape = torch.Size([])
+            shape = _size([])
         _dict = {}
         for key, item in self.items():
             if item is not None:
@@ -4453,7 +4455,7 @@ def rand(self, shape: torch.Size = None) -> TensorDictBase:
         # TensorDict requirements
         return TensorDict._new_unsafe(
             _dict,
-            batch_size=torch.Size([*shape, *self.shape]),
+            batch_size=_size([*shape, *self.shape]),
             device=self._device,
         )
 
@@ -4621,7 +4623,7 @@ def to_numpy(self, val: TensorDict, safe: bool = None) -> dict:
 
     def zero(self, shape: torch.Size = None) -> TensorDictBase:
         if shape is None:
-            shape = torch.Size([])
+            shape = _size([])
         try:
             device = self.device
         except RuntimeError:
@@ -4632,7 +4634,7 @@ def zero(self, shape: torch.Size = None) -> TensorDictBase:
                 for key in self.keys(True)
                 if isinstance(key, str) and self[key] is not None
             },
-            torch.Size([*shape, *self._safe_shape]),
+            _size([*shape, *self._safe_shape]),
             device=device,
         )
 
@@ -5078,7 +5080,7 @@ def shape(self):
             if dim < 0:
                 dim = len(shape) + dim + 1
             shape.insert(dim, len(self._specs))
-        return torch.Size(shape)
+        return _size(shape)
 
     def expand(self, *shape):
         if len(shape) == 1 and not isinstance(shape[0], (int,)):
@@ -5279,7 +5281,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None:
     if dim is None:
         if len(shape) == 1 or shape.count(1) == 0:
             return None
-        new_shape = torch.Size([s for s in shape if s != 1])
+        new_shape = _size([s for s in shape if s != 1])
     else:
         if dim < 0:
             dim += len(shape)
@@ -5287,7 +5289,7 @@ def _squeezed_shape(shape: torch.Size, dim: int | None) -> torch.Size | None:
 
         if shape[dim] != 1:
             return None
 
-        new_shape = torch.Size([s for i, s in enumerate(shape) if i != dim])
+        new_shape = _size([s for i, s in enumerate(shape) if i != dim])
 
     return new_shape
@@ -5303,7 +5305,7 @@ def _unsqueezed_shape(shape: torch.Size, dim: int) -> torch.Size:
 
     new_shape = list(shape)
     new_shape.insert(dim, 1)
-    return torch.Size(new_shape)
+    return _size(new_shape)
 
 
 class _CompositeSpecItemsView:
@@ -5451,7 +5453,7 @@ def _remove_neg_shapes(*shape):
         if isinstance(shape, np.integer):
             shape = (int(shape),)
         return _remove_neg_shapes(*shape)
-    return torch.Size([int(d) if d >= 0 else 1 for d in shape])
+    return _size([int(d) if d >= 0 else 1 for d in shape])
 
 
 ##############
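Note on the final hunk: the patched `_remove_neg_shapes` return line in isolation (a sketch; the helper name below is local to this example, and the real function also normalizes scalar `np.integer` input, as the context lines show):

    import numpy as np
    import torch

    def remove_neg_shapes_sketch(*shape):
        if len(shape) == 1 and not isinstance(shape[0], (int, np.integer)):
            shape = shape[0]
        # mirrors the patched line: cast to int, map dynamic (-1) dims to 1
        return torch.Size([int(d) if d >= 0 else 1 for d in shape])

    print(remove_neg_shapes_sketch(-1, 3))              # torch.Size([1, 3])
    print(remove_neg_shapes_sketch([np.int64(2), -1]))  # torch.Size([2, 1])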