diff --git a/test/test_specs.py b/test/test_specs.py index 5334281f0ee..a75ff0352c7 100644 --- a/test/test_specs.py +++ b/test/test_specs.py @@ -59,316 +59,278 @@ ) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -def test_bounded(dtype): - torch.manual_seed(0) - np.random.seed(0) - for _ in range(100): - bounds = torch.randn(2).sort()[0] - ts = Bounded(bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype) - _dtype = dtype - if dtype is None: - _dtype = torch.get_default_dtype() - - r = ts.rand() - assert ts.is_in(r) - assert r.dtype is _dtype - ts.is_in(ts.encode(bounds.mean())) - ts.is_in(ts.encode(bounds.mean().item())) - assert (ts.encode(ts.to_numpy(r)) == r).all() - - -@pytest.mark.parametrize("cls", [OneHot, Categorical]) -def test_discrete(cls): - torch.manual_seed(0) - np.random.seed(0) +class TestRanges: + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + def test_bounded(self, dtype): + torch.manual_seed(0) + np.random.seed(0) + for _ in range(100): + bounds = torch.randn(2).sort()[0] + ts = Bounded( + bounds[0].item(), bounds[1].item(), torch.Size((1,)), dtype=dtype + ) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() - ts = cls(10) - for _ in range(100): - r = ts.rand() - ts.to_numpy(r) - ts.encode(torch.tensor([5])) - ts.encode(torch.tensor(5).numpy()) - ts.encode(9) - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.encode(torch.tensor([11])) # out of bounds - assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE - assert ts.is_in(r) - assert (ts.encode(ts.to_numpy(r)) == r).all() + r = ts.rand() + assert (ts._project(r) == r).all() + assert ts.is_in(r) + assert r.dtype is _dtype + ts.is_in(ts.encode(bounds.mean())) + ts.is_in(ts.encode(bounds.mean().item())) + assert (ts.encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize("cls", [OneHot, Categorical]) + def test_discrete(self, cls): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -def test_unbounded(dtype): - torch.manual_seed(0) - np.random.seed(0) - ts = Unbounded(dtype=dtype) - - if dtype is None: - dtype = torch.get_default_dtype() - for _ in range(100): - r = ts.rand() - ts.to_numpy(r) - assert ts.is_in(r) - assert r.dtype is dtype - assert (ts.encode(ts.to_numpy(r)) == r).all() + ts = cls(10) + for _ in range(100): + r = ts.rand() + assert (ts._project(r) == r).all() + ts.to_numpy(r) + ts.encode(torch.tensor([5])) + ts.encode(torch.tensor(5).numpy()) + ts.encode(9) + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.encode(torch.tensor([11])) # out of bounds + assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE + assert ts.is_in(r) + assert (ts.encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + def test_unbounded(self, dtype): + torch.manual_seed(0) + np.random.seed(0) + ts = Unbounded(dtype=dtype) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -@pytest.mark.parametrize("shape", [[], torch.Size([3])]) -def test_ndbounded(dtype, shape): - torch.manual_seed(0) - np.random.seed(0) - - for _ in range(100): - lb = torch.rand(10) - 1 - ub = torch.rand(10) + 1 - ts = Bounded(lb, ub, dtype=dtype) - _dtype = dtype if dtype is None: - _dtype = torch.get_default_dtype() - - r = ts.rand(shape) - assert r.dtype is _dtype - assert r.shape == torch.Size([*shape, 10]) - assert (r >= lb.to(dtype)).all() and ( - r <= ub.to(dtype) - ).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} " - ts.to_numpy(r) - assert ts.is_in(r) - ts.encode(lb + torch.rand(10) * (ub - lb)) - ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) - - if not shape: + dtype = torch.get_default_dtype() + for _ in range(100): + r = ts.rand() + assert (ts._project(r) == r).all() + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() - - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.encode(torch.rand(10) + 3) # out of bounds - with pytest.raises(AssertionError), set_global_var( - torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True - ): - ts.to_numpy(torch.rand(10) + 3) # out of bounds - assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] + ) + @pytest.mark.parametrize("shape", [[], torch.Size([3])]) + def test_ndbounded(self, dtype, shape): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.float64, None]) -@pytest.mark.parametrize("n", range(3, 10)) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_ndunbounded(dtype, n, shape): - torch.manual_seed(0) - np.random.seed(0) + for _ in range(100): + lb = torch.rand(10) - 1 + ub = torch.rand(10) + 1 + ts = Bounded(lb, ub, dtype=dtype) + _dtype = dtype + if dtype is None: + _dtype = torch.get_default_dtype() + + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.dtype is _dtype + assert r.shape == torch.Size([*shape, 10]) + assert (r >= lb.to(dtype)).all() and ( + r <= ub.to(dtype) + ).all(), f"{r[r <= lb] - lb.expand_as(r)[r <= lb]} -- {r[r >= ub] - ub.expand_as(r)[r >= ub]} " + ts.to_numpy(r) + assert ts.is_in(r) + ts.encode(lb + torch.rand(10) * (ub - lb)) + ts.encode((lb + torch.rand(10) * (ub - lb)).numpy()) + + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.encode(torch.rand(10) + 3) # out of bounds + with pytest.raises(AssertionError), set_global_var( + torchrl.data.tensor_specs, "_CHECK_SPEC_ENCODE", True + ): + ts.to_numpy(torch.rand(10) + 3) # out of bounds + assert not torchrl.data.tensor_specs._CHECK_SPEC_ENCODE - ts = Unbounded( - shape=[ - n, - ], - dtype=dtype, + @pytest.mark.parametrize( + "dtype", [torch.float32, torch.float16, torch.float64, None] ) + @pytest.mark.parametrize("n", range(3, 10)) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_ndunbounded(self, dtype, n, shape): + torch.manual_seed(0) + np.random.seed(0) - if dtype is None: - dtype = torch.get_default_dtype() + ts = Unbounded(shape=[n], dtype=dtype) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - n, - ] - ) - ts.to_numpy(r) - assert ts.is_in(r) - assert r.dtype is dtype - if not shape: - assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + if dtype is None: + dtype = torch.get_default_dtype() + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size( + [ + *shape, + n, + ] + ) + ts.to_numpy(r) + assert ts.is_in(r) + assert r.dtype is dtype + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + + @pytest.mark.parametrize("n", range(3, 10)) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_binary(self, n, shape): + torch.manual_seed(0) + np.random.seed(0) -@pytest.mark.parametrize("n", range(3, 10)) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_binary(n, shape): - torch.manual_seed(0) - np.random.seed(0) - - ts = Binary(n) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - n, - ] - ) - assert ts.is_in(r) - assert ((r == 0) | (r == 1)).all() - if not shape: - assert (ts.encode(ts.to_numpy(r)) == r).all() - else: - with pytest.raises(RuntimeError, match="Shape mismatch"): - ts.encode(ts.to_numpy(r)) - assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + ts = Binary(n) + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size([*shape, n]) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + if not shape: + assert (ts.encode(ts.to_numpy(r)) == r).all() + else: + with pytest.raises(RuntimeError, match="Shape mismatch"): + ts.encode(ts.to_numpy(r)) + assert (ts.expand(*shape, *ts.shape).encode(ts.to_numpy(r)) == r).all() + @pytest.mark.parametrize( + "ns", + [ + [5], + [5, 2, 3], + [4, 4, 1], + ], + ) + @pytest.mark.parametrize("shape", [(), torch.Size([3])]) + def test_mult_onehot(self, shape, ns): + torch.manual_seed(0) + np.random.seed(0) + ts = MultiOneHot(nvec=ns) + for _ in range(100): + r = ts.rand(shape) + assert (ts._project(r) == r).all() + assert r.shape == torch.Size([*shape, sum(ns)]) + assert ts.is_in(r) + assert ((r == 0) | (r == 1)).all() + rsplit = r.split(ns, dim=-1) + for _r, _n in zip(rsplit, ns): + assert (_r.sum(-1) == 1).all() + assert _r.shape[-1] == _n + categorical = ts.to_categorical(r) + assert not ts.is_in(categorical) + # assert (ts.encode(categorical) == r).all() + if not shape: + assert (ts.encode(categorical) == r).all() + else: + with pytest.raises(RuntimeError, match="is invalid for input of size"): + ts.encode(categorical) + assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all() -@pytest.mark.parametrize( - "ns", - [ + @pytest.mark.parametrize( + "ns", [ 5, + [5, 2, 3], + [4, 5, 1, 3], + [[1, 2], [3, 4]], + [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], ], - [5, 2, 3], - [4, 4, 1], - ], -) -@pytest.mark.parametrize( - "shape", - [ - [], - torch.Size( - [ - 3, - ] - ), - ], -) -def test_mult_onehot(shape, ns): - torch.manual_seed(0) - np.random.seed(0) - ts = MultiOneHot(nvec=ns) - for _ in range(100): - r = ts.rand(shape) - assert r.shape == torch.Size( - [ - *shape, - sum(ns), - ] - ) - assert ts.is_in(r) - assert ((r == 0) | (r == 1)).all() - rsplit = r.split(ns, dim=-1) - for _r, _n in zip(rsplit, ns): - assert (_r.sum(-1) == 1).all() - assert _r.shape[-1] == _n - categorical = ts.to_categorical(r) - assert not ts.is_in(categorical) - # assert (ts.encode(categorical) == r).all() - if not shape: - assert (ts.encode(categorical) == r).all() - else: - with pytest.raises(RuntimeError, match="is invalid for input of size"): - ts.encode(categorical) - assert (ts.expand(*shape, *ts.shape).encode(categorical) == r).all() - - -@pytest.mark.parametrize( - "ns", - [ - 5, - [5, 2, 3], - [4, 5, 1, 3], - [[1, 2], [3, 4]], - [[[2, 4], [3, 5]], [[4, 5], [2, 3]], [[2, 3], [3, 2]]], - ], -) -@pytest.mark.parametrize("shape", [None, [], torch.Size([3]), torch.Size([4, 5])]) -@pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) -def test_multi_discrete(shape, ns, dtype): - torch.manual_seed(0) - np.random.seed(0) - ts = MultiCategorical(ns, dtype=dtype) - _real_shape = shape if shape is not None else [] - nvec_shape = torch.tensor(ns).size() - for _ in range(100): - r = ts.rand(shape) - - assert r.shape == torch.Size( - [ - *_real_shape, - *nvec_shape, - ] - ), (r.shape, ns, shape, _real_shape, nvec_shape) - assert ts.is_in(r), (r, r.shape, ns) - rand = torch.rand( - torch.Size( - [ - *_real_shape, - *nvec_shape, - ] - ) ) - projection = ts._project(rand) - - assert rand.shape == projection.shape - assert ts.is_in(projection) - if projection.ndim < 1: - projection.fill_(-1) - else: - projection[..., 0] = -1 - assert not ts.is_in(projection) - - -@pytest.mark.parametrize("n", [1, 4, 7, 99]) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) -def test_discrete_conversion(n, device, shape): - categorical = Categorical(n, device=device, shape=shape) - shape_one_hot = [n] if not shape else [*shape, n] - one_hot = OneHot(n, device=device, shape=shape_one_hot) - - assert categorical != one_hot - assert categorical.to_one_hot_spec() == one_hot - assert one_hot.to_categorical_spec() == categorical - - categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) - assert categorical.is_in(categorical_recon), (categorical, categorical_recon) - one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) - assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + @pytest.mark.parametrize("shape", [None, [], torch.Size([3]), torch.Size([4, 5])]) + @pytest.mark.parametrize("dtype", [torch.float, torch.int, torch.long]) + def test_multi_discrete(self, shape, ns, dtype): + torch.manual_seed(0) + np.random.seed(0) + ts = MultiCategorical(ns, dtype=dtype) + _real_shape = shape if shape is not None else [] + nvec_shape = torch.tensor(ns).size() + for _ in range(100): + r = ts.rand(shape) + assert r.shape == torch.Size( + [ + *_real_shape, + *nvec_shape, + ] + ), (r.shape, ns, shape, _real_shape, nvec_shape) + assert ts.is_in(r), (r, r.shape, ns) + rand = torch.rand( + torch.Size( + [ + *_real_shape, + *nvec_shape, + ] + ) + ) + projection = ts._project(rand) -@pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]]) -@pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) -@pytest.mark.parametrize("device", get_default_devices()) -def test_multi_discrete_conversion(ns, shape, device): - categorical = MultiCategorical(ns, device=device) - one_hot = MultiOneHot(ns, device=device) - - assert categorical != one_hot - assert categorical.to_one_hot_spec() == one_hot - assert one_hot.to_categorical_spec() == categorical - - categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) - assert categorical.is_in(categorical_recon), (categorical, categorical_recon) - one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) - assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + assert rand.shape == projection.shape + assert ts.is_in(projection) + if projection.ndim < 1: + projection.fill_(-1) + else: + projection[..., 0] = -1 + assert not ts.is_in(projection) + + @pytest.mark.parametrize("n", [1, 4, 7, 99]) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("shape", [None, [], [1], [1, 2]]) + def test_discrete_conversion(self, n, device, shape): + categorical = Categorical(n, device=device, shape=shape) + shape_one_hot = [n] if not shape else [*shape, n] + one_hot = OneHot(n, device=device, shape=shape_one_hot) + + assert categorical != one_hot + assert categorical.to_one_hot_spec() == one_hot + assert one_hot.to_categorical_spec() == categorical + + categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) + assert categorical.is_in(categorical_recon), (categorical, categorical_recon) + one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) + assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) + + @pytest.mark.parametrize("ns", [[5], [5, 2, 3], [4, 5, 1, 3]]) + @pytest.mark.parametrize("shape", [torch.Size([3]), torch.Size([4, 5])]) + @pytest.mark.parametrize("device", get_default_devices()) + def test_multi_discrete_conversion(self, ns, shape, device): + categorical = MultiCategorical(ns, device=device) + one_hot = MultiOneHot(ns, device=device) + + assert categorical != one_hot + assert categorical.to_one_hot_spec() == one_hot + assert one_hot.to_categorical_spec() == categorical + + categorical_recon = one_hot.to_categorical(one_hot.rand(shape)) + assert categorical.is_in(categorical_recon), (categorical, categorical_recon) + one_hot_recon = categorical.to_one_hot(categorical.rand(shape)) + assert one_hot.is_in(one_hot_recon), (one_hot, one_hot_recon) @pytest.mark.parametrize("is_complete", [True, False]) diff --git a/torchrl/data/tensor_specs.py b/torchrl/data/tensor_specs.py index 3733d4e2650..0f6acae9803 100644 --- a/torchrl/data/tensor_specs.py +++ b/torchrl/data/tensor_specs.py @@ -44,6 +44,11 @@ from tensordict.utils import _getitem_batch_size, is_non_tensor, NestedKey from torchrl._utils import _make_ordinal_device, get_binary_env_var, implement_for +try: + from torch.compiler import is_compiling +except ImportError: + from torch._dynamo import is_compiling + DEVICE_TYPING = Union[torch.device, str, int] INDEX_TYPING = Union[int, torch.Tensor, np.ndarray, slice, List] @@ -877,7 +882,7 @@ def project( a torch.Tensor belonging to the TensorSpec box. """ - if not self.is_in(val): + if is_compiling() or not self.is_in(val): return self._project(val) return val @@ -2696,7 +2701,9 @@ def is_in(self, val: torch.Tensor) -> bool: return val.shape == shape and val.dtype == self.dtype def _project(self, val: torch.Tensor) -> torch.Tensor: - return torch.as_tensor(val, dtype=self.dtype).reshape(self.shape) + return torch.as_tensor(val, dtype=self.dtype).reshape( + val.shape[: -self.ndim] + self.shape + ) def enumerate(self) -> Any: raise NotImplementedError("enumerate cannot be called with continuous specs.")