Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 23, 2024
1 parent f068784 commit 199085a
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
31 changes: 29 additions & 2 deletions docs/source/reference/envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -798,8 +798,35 @@ In some cases, creating a testing environment where images can be collected is t
(some libraries only allow one environment instance per workspace).
In these cases, assuming that a `render` method is available in the environment, the :class:`~torchrl.record.PixelRenderTransform`
can be used to call `render` on the parent environment and save the images in the rollout data stream.
This class should only be used within the same process as the environment that is being rendered (remote calls to `render`
are not allowed).
This class works over single and batched environments alike:

>>> from torchrl.envs import GymEnv, check_env_specs, ParallelEnv, EnvCreator
>>> from torchrl.record.loggers import CSVLogger
>>> from torchrl.record.recorder import PixelRenderTransform, VideoRecorder
>>>
>>> def make_env():
...     env = GymEnv("CartPole-v1", render_mode="rgb_array")
...     # Uncomment this line to execute per-env
...     # env = env.append_transform(PixelRenderTransform())
...     return env
>>>
>>> if __name__ == "__main__":
... logger = CSVLogger("dummy", video_format="mp4")
...
... env = ParallelEnv(16, EnvCreator(make_env))
... env.start()
... # Comment this line to execute per-env
... env = env.append_transform(PixelRenderTransform())
...
... env = env.append_transform(VideoRecorder(logger=logger, tag="pixels_record"))
... env.rollout(3)
...
... check_env_specs(env)
...
... r = env.rollout(30)
... env.transform.dump()
... env.close()


.. currentmodule:: torchrl.record

Expand Down
1 change: 1 addition & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,7 @@ def __getattr__(self, attr: str) -> Any:
try:
# _ = getattr(self._dummy_env, attr)
if self.is_closed:
self.start()
raise RuntimeError(
"Trying to access attributes of closed/non started "
"environments. Check that the batched environment "
Expand Down
7 changes: 5 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2298,7 +2298,7 @@ def rollout(
self,
max_steps: int,
policy: Optional[Callable[[TensorDictBase], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], TensorDictBase]] = None,
callback: Optional[Callable[[TensorDictBase, ...], Any]] = None,
auto_reset: bool = True,
auto_cast_to_device: bool = False,
break_when_any_done: bool = True,
Expand All @@ -2320,7 +2320,10 @@ def rollout(
The policy can be any callable that reads either a tensordict or
the entire sequence of observation entries __sorted as__ the ``env.observation_spec.keys()``.
Defaults to `None`.
callback (callable, optional): function to be called at each iteration with the given TensorDict.
callback (Callable[[TensorDict], Any], optional): function to be called at each iteration with the given
TensorDict. Defaults to ``None``. The output of ``callback`` will not be collected; it is the user's
responsibility to save any result within the callback call if data needs to be carried over beyond
the call to ``rollout``.
auto_reset (bool, optional): if ``True``, resets automatically the environment
if it is in a done state when the rollout is initiated.
Default is ``True``.
Expand Down
43 changes: 42 additions & 1 deletion torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import importlib.util
from copy import copy
from typing import Callable, List, Optional, Sequence
from typing import Callable, List, Optional, Sequence, Union

import numpy as np
import torch
Expand All @@ -19,6 +19,7 @@
from torchrl.data import TensorSpec
from torchrl.data.tensor_specs import NonTensorSpec, UnboundedContinuousTensorSpec
from torchrl.data.utils import CloudpickleWrapper
from torchrl.envs import EnvBase
from torchrl.envs.transforms import ObservationTransform, Transform
from torchrl.record.loggers import Logger

Expand Down Expand Up @@ -335,6 +336,7 @@ class PixelRenderTransform(Transform):
This transform offers an alternative to the ``from_pixels`` syntactic sugar for cases where instantiating an
environment that offers rendering is expensive, or when ``from_pixels`` is not implemented.
It can be used within a single environment or over batched environments alike.
Args:
out_keys (List[NestedKey] or Nested): List of keys where to register the pixel observations.
Expand Down Expand Up @@ -400,6 +402,15 @@ class PixelRenderTransform(Transform):
>>> r = env.rollout(30)
>>> env.transform[-1].dump()
The transform can be disabled using the :meth:`~.switch` method, which will turn the rendering on if it's off
or off if it's on (an argument can also be passed to control this behaviour). Since transforms are
:class:`~torch.nn.Module` instances, :meth:`~torch.nn.Module.apply` can be used to control this behaviour:
>>> def switch(module):
... if isinstance(module, PixelRenderTransform):
... module.switch()
>>> env.apply(switch)
"""

def __init__(
Expand All @@ -426,6 +437,7 @@ def __init__(
self.as_non_tensor = as_non_tensor
self.kwargs = kwargs
self.render_method = render_method
self._enabled = True
super().__init__(in_keys=[], out_keys=out_keys)

def _reset(
Expand All @@ -434,10 +446,18 @@ def _reset(
return self._call(tensordict_reset)

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if not self._enabled:
return tensordict

array = getattr(self.parent, self.render_method)(**self.kwargs)
if self.preproc:
array = self.preproc(array)
if self.as_non_tensor is None:
if isinstance(array, list):
if isinstance(array[0], np.ndarray):
array = np.asarray(array)
else:
array = torch.as_tensor(array)
if (
array.ndim == 3
and array.shape[-1] == 3
Expand Down Expand Up @@ -475,3 +495,24 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
)
observation_spec[self.out_keys[0]] = spec
return observation_spec

def switch(self, mode: Optional[Union[str, bool]] = None) -> None:
    """Sets the transform on or off.

    Args:
        mode (str or bool, optional): the target state. ``True`` or ``"on"``
            sets the internal ``_enabled`` flag to ``True``; ``False`` or
            ``"off"`` sets it to ``False``. If ``None`` (default), the
            current value of the flag is flipped.

    Raises:
        ValueError: if ``mode`` is not ``None``, not a boolean, and not one
            of the strings ``"on"`` / ``"off"``.
    """
    if mode is None:
        # No explicit target was given: toggle the current state.
        mode = not self._enabled
    elif not isinstance(mode, bool):
        # Non-boolean modes are only accepted as the literal strings below.
        if mode not in ("on", "off"):
            raise ValueError("mode must be either 'on' or 'off', or a boolean.")
        mode = mode == "on"
    self._enabled = mode

def set_container(self, container: Union[Transform, EnvBase]) -> None:
    """Registers the container and validates the render method of the parent env.

    After delegating to the parent class' ``set_container``, the attribute
    named by ``self.render_method`` is looked up on ``self.parent`` whenever
    that parent is an :class:`EnvBase` instance.

    Args:
        container (Transform or EnvBase): the container to attach this transform to.

    Returns:
        Whatever the parent class' ``set_container`` returns.

    Raises:
        ValueError: if the parent env has no callable attribute named
            ``self.render_method``.
    """
    result = super().set_container(container)
    if not isinstance(self.parent, EnvBase):
        # Nothing to validate until an actual environment is reachable.
        return result
    # NOTE(review): the attribute lookup presumably also starts a batched env
    # that has not been started yet -- confirm against the env's __getattr__.
    method = getattr(self.parent, self.render_method, None)
    if not callable(method):  # also covers the ``None`` (missing) case
        raise ValueError(
            f"The render method must exist and be a callable. Got render={method}."
        )
    return result

0 comments on commit 199085a

Please sign in to comment.