Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Make sure ParallelEnv does not overflow mem when policy requires grad #1909

Merged
merged 7 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import gc
import os.path
import re
Expand Down Expand Up @@ -65,7 +66,14 @@
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv
from torchrl.envs import (
CatTensors,
DoubleToFloat,
EnvBase,
EnvCreator,
ParallelEnv,
SerialEnv,
)
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
Expand Down Expand Up @@ -2473,6 +2481,87 @@ def test_auto_cast_to_device(break_when_any_done):
assert_allclose_td(rollout0, rollout1)


@pytest.mark.parametrize("device", get_default_devices())
def test_backprop(device):
    """Check how gradients flow through single, serial and parallel envs.

    Two properties are verified:

    * the gradient of the summed rewards w.r.t. the policy parameters is the
      same whether the rollouts come from two stand-alone envs (stacked
      afterwards) or from a ``SerialEnv`` built from the same two envs;
    * a rollout collected with a ``ParallelEnv`` does not require grad once
      the ``"action"`` entry is excluded.

    Args:
        device: device on which the envs, their specs and the policy are
            created.
    """

    class DifferentiableEnv(EnvBase):
        """Toy env whose observation and reward are computed from the action."""

        def __init__(self, device):
            super().__init__(device=device)
            self.observation_spec = CompositeSpec(
                observation=UnboundedContinuousTensorSpec(3, device=device),
                device=device,
            )
            self.action_spec = CompositeSpec(
                action=UnboundedContinuousTensorSpec(3, device=device), device=device
            )
            self.reward_spec = CompositeSpec(
                reward=UnboundedContinuousTensorSpec(1, device=device), device=device
            )
            self.seed = 0

        def _set_seed(self, seed):
            # The seed is only stored; it is used as an offset in ``_reset``.
            self.seed = seed
            return seed

        def _reset(self, tensordict):
            td = self.observation_spec.zero().update(self.done_spec.zero())
            # Offset the zero observation by ``seed % 10`` and mark it as
            # requiring grad so the rollout is differentiable from the start.
            td["observation"] = (
                td["observation"].clone() + self.seed % 10
            ).requires_grad_()
            return td

        def _step(self, tensordict):
            action = tensordict.get("action")
            # Both outputs below are differentiable functions of the action.
            obs = (tensordict.get("observation") + action) / action.norm()
            return TensorDict(
                {
                    "reward": action.sum().unsqueeze(0),
                    **self.full_done_spec.zero(),
                    "observation": obs,
                }
            )

    torch.manual_seed(0)
    policy = Actor(torch.nn.Linear(3, 3, device=device))

    # Reference: two independent envs, rollouts stacked after collection.
    env0 = DifferentiableEnv(device=device)
    # The value returned by ``set_seed`` is reused to seed the second env.
    seed = env0.set_seed(0)
    env1 = DifferentiableEnv(device=device)
    env1.set_seed(seed)
    r0 = env0.rollout(10, policy)
    r1 = env1.rollout(10, policy)
    r = torch.stack([r0, r1])
    g = torch.autograd.grad(r["next", "reward"].sum(), policy.parameters())

    def make_env(seed, device=device):
        env = DifferentiableEnv(device=device)
        env.set_seed(seed)
        return env

    # Same two envs (same seeds) wrapped in a SerialEnv: gradients must match
    # the reference ones.
    serial_env = SerialEnv(
        2,
        [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
        device=device,
    )
    try:
        r_serial = serial_env.rollout(10, policy)

        g_serial = torch.autograd.grad(
            r_serial["next", "reward"].sum(), policy.parameters()
        )
        torch.testing.assert_close(g, g_serial)
    finally:
        # Close the serial env as well, mirroring what is done for the
        # parallel env below.
        serial_env.close()

    # With a ParallelEnv nothing but the policy output may require grad.
    p_env = ParallelEnv(
        2,
        [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
        device=device,
    )
    try:
        r_parallel = p_env.rollout(10, policy)
        assert not r_parallel.exclude("action").requires_grad
    finally:
        p_env.close()


if __name__ == "__main__":
    # Run this file through pytest, forwarding any command-line flags that
    # argparse does not recognize.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
9 changes: 9 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,12 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
>>> # If no cuda device is available
>>> env = ParallelEnv(N, MyEnv(..., device="cpu"))

.. warning::
ParallelEnv disables gradients in all operations (:meth:`~.step`,
:meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients
cannot be passed through :class:`multiprocessing.Pipe` objects.
Only :class:`~torchrl.envs.SerialEnv` will support backpropagation.

"""

def _start_workers(self) -> None:
Expand Down Expand Up @@ -1143,6 +1149,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
event.wait()
event.clear()

@torch.no_grad()
@_check_start
def step_and_maybe_reset(
self, tensordict: TensorDictBase
Expand Down Expand Up @@ -1205,6 +1212,7 @@ def step_and_maybe_reset(
tensordict.set("next", next_td)
return tensordict, tensordict_

@torch.no_grad()
@_check_start
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We must use the in_keys and nothing else for the following reasons:
Expand Down Expand Up @@ -1261,6 +1269,7 @@ def select_and_clone(name, tensor):
out = out.to(device, non_blocking=self.non_blocking)
return out

@torch.no_grad()
@_check_start
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict is not None:
Expand Down
Loading