Skip to content

Commit

Permalink
[BugFix] Remove reset on last step of a rollout (#1936)
Browse files Browse the repository at this point in the history
Co-authored-by: vmoens <[email protected]>
  • Loading branch information
matteobettini and vmoens authored Feb 21, 2024
1 parent 4080cf3 commit 03f4aa3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 5 deletions.
29 changes: 29 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,35 @@ def test_rollout(env_name, frame_skip, seed=0):
env.close()


@pytest.mark.parametrize("max_steps", [1, 5])
def test_rollouts_chaining(max_steps, batch_size=(4,), epochs=4):
# CountingEnv is done at max_steps + 1, so to emulate it being done at max_steps, we feed max_steps=max_steps - 1
env = CountingEnv(max_steps=max_steps - 1, batch_size=batch_size)
policy = CountingEnvCountPolicy(
action_spec=env.action_spec, action_key=env.action_key
)

input_td = env.reset()
for _ in range(epochs):
rollout_td = env.rollout(
max_steps=max_steps,
policy=policy,
auto_reset=False,
break_when_any_done=False,
tensordict=input_td,
)
assert (env.count == max_steps).all()
input_td = step_mdp(
rollout_td[..., -1],
keep_other=True,
exclude_action=False,
exclude_reward=True,
reward_keys=env.reward_keys,
action_keys=env.action_keys,
done_keys=env.done_keys,
)


@pytest.mark.parametrize("device", get_default_devices())
def test_rollout_predictability(device):
env = MockSerialEnv(device=device)
Expand Down
52 changes: 47 additions & 5 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2273,7 +2273,9 @@ def rollout(
called on the sub-envs that are done. Default is True.
return_contiguous (bool): if False, a LazyStackedTensorDict will be returned. Default is True.
tensordict (TensorDict, optional): if auto_reset is False, an initial
tensordict must be provided.
tensordict must be provided. Rollout will check if this tensordict has done flags and reset the
environment in those dimensions (if needed). This normally should not occur if ``tensordict`` is the
output of a reset, but can occur if ``tensordict`` is the last step of a previous rollout.
Returns:
TensorDict object containing the resulting trajectory.
Expand Down Expand Up @@ -2369,6 +2371,26 @@ def rollout(
>>> print(rollout.names)
[None, 'time']
Rollouts can be used in a loop to emulate data collection.
To do so, you need to pass as input the last tensordict coming from the previous rollout after calling
:func:`~torchrl.envs.utils.step_mdp` on it.
Examples:
>>> from torchrl.envs import GymEnv, step_mdp
>>> env = GymEnv("CartPole-v1")
>>> epochs = 10
>>> input_td = env.reset()
>>> for i in range(epochs):
... rollout_td = env.rollout(
... max_steps=100,
... break_when_any_done=False,
... auto_reset=False,
... tensordict=input_td,
... )
... input_td = step_mdp(
... rollout_td[..., -1],
... )
"""
if auto_cast_to_device:
try:
Expand All @@ -2388,6 +2410,9 @@ def rollout(
tensordict = self.reset()
elif tensordict is None:
raise RuntimeError("tensordict must be provided when auto_reset is False")
else:
tensordict = self.maybe_reset(tensordict)

if policy is None:

policy = self.rand_action
Expand Down Expand Up @@ -2493,7 +2518,10 @@ def _rollout_nonstop(
tensordict_ = tensordict_.to(env_device, non_blocking=True)
else:
tensordict_.clear_device_()
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
if i == max_steps - 1:
tensordict = self.step(tensordict_)
else:
tensordict, tensordict_ = self.step_and_maybe_reset(tensordict_)
tensordicts.append(tensordict)
if i == max_steps - 1:
# we don't truncated as one could potentially continue the run
Expand Down Expand Up @@ -2557,14 +2585,28 @@ def step_and_maybe_reset(
action_keys=self.action_keys,
done_keys=self.done_keys,
)
tensordict_ = self.maybe_reset(tensordict_)
return tensordict, tensordict_

def maybe_reset(self, tensordict: TensorDictBase) -> TensorDictBase:
"""Checks the done keys of the input tensordict and, if needed, resets the environment where it is done.
Args:
tensordict (TensorDictBase): a tensordict coming from the output of :func:`~torchrl.envs.utils.step_mdp`.
Returns:
A tensordict that is identical to the input where the environment was
not reset and contains the new reset data where the environment was reset.
"""
any_done = _terminated_or_truncated(
tensordict_,
tensordict,
full_done_spec=self.output_spec["full_done_spec"],
key="_reset",
)
if any_done:
tensordict_ = self.reset(tensordict_)
return tensordict, tensordict_
tensordict = self.reset(tensordict)
return tensordict

def empty_cache(self):
"""Erases all the cached values.
Expand Down

0 comments on commit 03f4aa3

Please sign in to comment.