diff --git a/test/test_transforms.py b/test/test_transforms.py index 40c4f76e539..4f84001480f 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -56,6 +56,7 @@ LazyTensorStorage, ReplayBuffer, TensorDictReplayBuffer, + TensorSpec, TensorStorage, UnboundedContinuousTensorSpec, ) @@ -120,6 +121,7 @@ _has_tv, BatchSizeTransform, FORWARD_NOT_IMPLEMENTED, + Transform, ) from torchrl.envs.transforms.vc1 import _has_vc from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform @@ -7952,6 +7954,28 @@ def test_added_transforms_are_in_eval_mode(): class TestTransformedEnv: + def test_attr_error(self): + class BuggyTransform(Transform): + def transform_observation_spec( + self, observation_spec: TensorSpec + ) -> TensorSpec: + raise AttributeError + + def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec: + raise RuntimeError("reward!") + + env = TransformedEnv(CountingEnv(), BuggyTransform()) + with pytest.raises( + AttributeError, match="because an internal error was raised" + ): + env.observation_spec + with pytest.raises( + AttributeError, match="'CountingEnv' object has no attribute 'tralala'" + ): + env.tralala + with pytest.raises(RuntimeError, match="reward!"): + env.transform.transform_reward_spec(env.base_env.full_reward_spec) + def test_independent_obs_specs_from_shared_env(self): obs_spec = CompositeSpec( observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,))) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index 6aeea0529ce..96c7ddfa35e 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -926,7 +926,12 @@ def __getattr__(self, attr: str) -> Any: return super().__getattr__( attr ) # make sure that appropriate exceptions are raised - except Exception as err: + except AttributeError as err: + if attr.endswith("_spec"): + raise AttributeError( + f"Could not get {attr} because an internal error was raised. To find what this error " + f"is, call env.transform.transform__spec(env.base_env.spec)." + ) if attr.startswith("__"): raise AttributeError( "passing built-in private methods is "