Skip to content

Commit

Permalink
[Quality] Capture errors in specs transforms (#2092)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 18, 2024
1 parent 6b87184 commit 9d3530f
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
24 changes: 24 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
LazyTensorStorage,
ReplayBuffer,
TensorDictReplayBuffer,
TensorSpec,
TensorStorage,
UnboundedContinuousTensorSpec,
)
Expand Down Expand Up @@ -120,6 +121,7 @@
_has_tv,
BatchSizeTransform,
FORWARD_NOT_IMPLEMENTED,
Transform,
)
from torchrl.envs.transforms.vc1 import _has_vc
from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
Expand Down Expand Up @@ -7952,6 +7954,28 @@ def test_added_transforms_are_in_eval_mode():


class TestTransformedEnv:
def test_attr_error(self):
class BuggyTransform(Transform):
def transform_observation_spec(
self, observation_spec: TensorSpec
) -> TensorSpec:
raise AttributeError

def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
raise RuntimeError("reward!")

env = TransformedEnv(CountingEnv(), BuggyTransform())
with pytest.raises(
AttributeError, match="because an internal error was raised"
):
env.observation_spec
with pytest.raises(
AttributeError, match="'CountingEnv' object has no attribute 'tralala'"
):
env.tralala
with pytest.raises(RuntimeError, match="reward!"):
env.transform.transform_reward_spec(env.base_env.full_reward_spec)

def test_independent_obs_specs_from_shared_env(self):
obs_spec = CompositeSpec(
observation=BoundedTensorSpec(low=0, high=10, shape=torch.Size((1,)))
Expand Down
7 changes: 6 additions & 1 deletion torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,12 @@ def __getattr__(self, attr: str) -> Any:
return super().__getattr__(
attr
) # make sure that appropriate exceptions are raised
except Exception as err:
except AttributeError as err:
if attr.endswith("_spec"):
raise AttributeError(
f"Could not get {attr} because an internal error was raised. To find what this error "
f"is, call env.transform.transform_<placeholder>_spec(env.base_env.spec)."
)
if attr.startswith("__"):
raise AttributeError(
"passing built-in private methods is "
Expand Down

0 comments on commit 9d3530f

Please sign in to comment.