From 0314e0558e915636220c1173628dbe42b86a40d2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 15 Feb 2024 15:29:15 +0000 Subject: [PATCH] [BugFix] Make sure ParallelEnv does not overflow mem when policy requires grad (#1909) --- test/test_env.py | 91 +++++++++++++++++++++++++++++++++++- torchrl/envs/batched_envs.py | 9 ++++ 2 files changed, 99 insertions(+), 1 deletion(-) diff --git a/test/test_env.py b/test/test_env.py index 15bcf5e3fcb..e2d2f1d9854 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import argparse +import functools import gc import os.path import re @@ -65,7 +66,14 @@ DiscreteTensorSpec, UnboundedContinuousTensorSpec, ) -from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv +from torchrl.envs import ( + CatTensors, + DoubleToFloat, + EnvBase, + EnvCreator, + ParallelEnv, + SerialEnv, +) from torchrl.envs.gym_like import default_info_dict_reader from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper @@ -2473,6 +2481,87 @@ def test_auto_cast_to_device(break_when_any_done): assert_allclose_td(rollout0, rollout1) +@pytest.mark.parametrize("device", get_default_devices()) +def test_backprop(device): + # Tests that backprop through a series of single envs and through a serial env are identical + # Also tests that no backprop can be achieved with parallel env. + class DifferentiableEnv(EnvBase): + def __init__(self, device): + super().__init__(device=device) + self.observation_spec = CompositeSpec( + observation=UnboundedContinuousTensorSpec(3, device=device), + device=device, + ) + self.action_spec = CompositeSpec( + action=UnboundedContinuousTensorSpec(3, device=device), device=device + ) + self.reward_spec = CompositeSpec( + reward=UnboundedContinuousTensorSpec(1, device=device), device=device + ) + self.seed = 0 + + def _set_seed(self, seed): + self.seed = seed + return seed + + def _reset(self, tensordict): + td = self.observation_spec.zero().update(self.done_spec.zero()) + td["observation"] = ( + td["observation"].clone() + self.seed % 10 + ).requires_grad_() + return td + + def _step(self, tensordict): + action = tensordict.get("action") + obs = (tensordict.get("observation") + action) / action.norm() + return TensorDict( + { + "reward": action.sum().unsqueeze(0), + **self.full_done_spec.zero(), + "observation": obs, + } + ) + + torch.manual_seed(0) + policy = Actor(torch.nn.Linear(3, 3, device=device)) + env0 = DifferentiableEnv(device=device) + seed = env0.set_seed(0) + env1 = DifferentiableEnv(device=device) + env1.set_seed(seed) + r0 = env0.rollout(10, policy) + r1 = env1.rollout(10, policy) + r = torch.stack([r0, r1]) + g = torch.autograd.grad(r["next", "reward"].sum(), policy.parameters()) + + def make_env(seed, device=device): + env = DifferentiableEnv(device=device) + env.set_seed(seed) + return env + + serial_env = SerialEnv( + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, + ) + r_serial = serial_env.rollout(10, policy) + + g_serial = torch.autograd.grad( + r_serial["next", "reward"].sum(), policy.parameters() + ) + torch.testing.assert_close(g, g_serial) + + p_env = ParallelEnv( + 2, + [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)], + device=device, + ) + try: + r_parallel = p_env.rollout(10, policy) + assert not r_parallel.exclude("action").requires_grad + finally: + p_env.close() + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 67802f01620..ae313ce5f19 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -1041,6 +1041,12 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta): >>> # If no cuda device is available >>> env = ParallelEnv(N, MyEnv(..., device="cpu")) + .. warning:: + ParallelEnv disable gradients in all operations (:meth:`~.step`, + :meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients + cannot be passed through :class:`multiprocessing.Pipe` objects. + Only :class:`~torchrl.envs.SerialEnv` will support backpropagation. + """ def _start_workers(self) -> None: @@ -1143,6 +1149,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None: event.wait() event.clear() + @torch.no_grad() @_check_start def step_and_maybe_reset( self, tensordict: TensorDictBase @@ -1205,6 +1212,7 @@ def step_and_maybe_reset( tensordict.set("next", next_td) return tensordict, tensordict_ + @torch.no_grad() @_check_start def _step(self, tensordict: TensorDictBase) -> TensorDictBase: # We must use the in_keys and nothing else for the following reasons: @@ -1261,6 +1269,7 @@ def select_and_clone(name, tensor): out = out.to(device, non_blocking=self.non_blocking) return out + @torch.no_grad() @_check_start def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: if tensordict is not None: