diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index de4446a51c5..5c39c5a1349 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -798,8 +798,35 @@ In some cases, creating a testing environment where images can be collected is t (some libraries only allow one environment instance per workspace). In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform` can be used to call `render` on the parent environment and save the images in the rollout data stream. -This class should only be used within the same process as the environment that is being rendered (remote calls to `render` -are not allowed). +This class works over single and batched environments alike: + + >>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator + >>> from torchrl.record.loggers import CSVLogger + >>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder + >>> + >>> def make_env(): + >>> env = GymEnv("CartPole-v1", render_mode="rgb_array") + >>> # Uncomment this line to execute per-env + >>> # env = env.append_transform(PixelRenderTransform()) + >>> return env + >>> + >>> if __name__ == "__main__": + ... logger = CSVLogger("dummy", video_format="mp4") + ... + ... env = ParallelEnv(16, EnvCreator(make_env)) + ... env.start() + ... # Comment this line to execute per-env + ... env = env.append_transform(PixelRenderTransform()) + ... + ... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record")) + ... env.rollout(3) + ... + ... check_env_specs(env) + ... + ... r = env.rollout(30) + ... env.transform.dump() + ... env.close() + .. currentmodule:: torchrl.record diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 660aecb3fd8..b3026da35ca 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1624,6 +1624,7 @@ def __getattr__(self, attr: str) -> Any: try: # _ = getattr(self._dummy_env, attr) if self.is_closed: + self.start() raise RuntimeError( "Trying to access attributes of closed/non started " "environments. Check that the batched environment " diff --git a/torchrl/envs/common.py b/torchrl/envs/common.py index f5d4625fd07..8712c74340a 100644 --- a/torchrl/envs/common.py +++ b/torchrl/envs/common.py @@ -2298,7 +2298,7 @@ def rollout( self, max_steps: int, policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None, - callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None, + callback: Optional[Callable[[TensorDictBase, ...], Any]] = None, auto_reset: bool = True, auto_cast_to_device: bool = False, break_when_any_done: bool = True, @@ -2320,7 +2320,10 @@ def rollout( The policy can be any callable that reads either a tensordict or the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``. Defaults to `None`. - callback (callable, optional): function to be called at each iteration with the given TensorDict. + callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given + TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected, it is the user + responsibility to save any result within the callback call if data needs to be carried over beyond + the call to ``rollout``. auto_reset (bool, optional): if ``True``, resets automatically the environment if it is in a done state when the rollout is initiated. Default is ``True``. diff --git a/torchrl/record/recorder.py b/torchrl/record/recorder.py index 9e8d681de7c..5033a2c93c2 100644 --- a/torchrl/record/recorder.py +++ b/torchrl/record/recorder.py @@ -6,7 +6,7 @@ import importlib.util from copy import copy -from typing import Callable, List, Optional, Sequence +from typing import Callable, List, Optional, Sequence, Union import numpy as np import torch @@ -19,6 +19,7 @@ from torchrl.data import TensorSpec from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec from torchrl.data.utils import CloudpickleWrapper +from torchrl.envs import EnvBase from torchrl.envs.transforms import ObservationTransform, Transform from torchrl.record.loggers import Logger @@ -335,6 +336,7 @@ class PixelRenderTransform(Transform): This transform offers an alternative to the ``from_pixels`` syntatic sugar when instantiating an environment that offers rendering is expensive, or when ``from_pixels`` is not implemented. + It can be used within a single environment or over batched environments alike. Args: out_keys (List[NestedKey] or Nested): List of keys where to register the pixel observations. @@ -400,6 +402,15 @@ class PixelRenderTransform(Transform): >>> r = env.rollout(30) >>> env.transform[-1].dump() + The transform can be disabled using the :meth:`~.switch` method, which will turn the rendering on if it's off + or off if it's on (an argument can also be passed to control this behaviour). Since transforms are + :class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control this behaviour: + + >>> def switch(module): + ... if isinstance(module, PixelRenderTransform): + ... module.switch() + >>> env.apply(switch) + """ def __init__( @@ -426,6 +437,7 @@ def __init__( self.as_non_tensor = as_non_tensor self.kwargs = kwargs self.render_method = render_method + self._enabled = True super().__init__(in_keys=[], out_keys=out_keys) def _reset( @@ -434,10 +446,18 @@ def _reset( return self._call(tensordict_reset) def _call(self, tensordict: TensorDictBase) -> TensorDictBase: + if not self._enabled: + return tensordict + array = getattr(self.parent, self.render_method)(**self.kwargs) if self.preproc: array = self.preproc(array) if self.as_non_tensor is None: + if isinstance(array, list): + if isinstance(array[0], np.ndarray): + array = np.asarray(array) + else: + array = torch.as_tensor(array) if ( array.ndim == 3 and array.shape[-1] == 3 @@ -475,3 +495,24 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec ) observation_spec[self.out_keys[0]] = spec return observation_spec + + def switch(self, mode: str | bool = None): + """Sets the transform on or off.""" + if mode is None: + mode = not self._enabled + if not isinstance(mode, bool): + if mode not in ("on", "off"): + raise ValueError("mode must be either 'on' or 'off', or a boolean.") + mode = mode == "on" + self._enabled = mode + + def set_container(self, container: Union[Transform, EnvBase]) -> None: + out = super().set_container(container) + if isinstance(self.parent, EnvBase): + # Start the env if needed + method = getattr(self.parent, self.render_method, None) + if method is None or not callable(method): + raise ValueError( + f"The render method must exist and be a callable. Got render={method}." + ) + return out