Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 15, 2024
1 parent 899af07 commit 2cf514f
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
58 changes: 57 additions & 1 deletion test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import gc
import os.path
import re
Expand Down Expand Up @@ -65,7 +66,8 @@
DiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv
from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, \
SerialEnv, EnvBase
from torchrl.envs.gym_like import default_info_dict_reader
from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv, GymWrapper
Expand Down Expand Up @@ -2473,6 +2475,60 @@ def test_auto_cast_to_device(break_when_any_done):
assert_allclose_td(rollout0, rollout1)


@pytest.mark.parametrize("device", get_default_devices())
def test_backprop(device):
    """Check gradient flow through single, serial and parallel environments.

    Two independent ``DifferentiableEnv`` rollouts are stacked and the
    gradient of the summed reward w.r.t. the policy parameters is compared
    with the gradient obtained from a two-worker ``SerialEnv`` rollout built
    from the same seeds. The test then checks that a ``ParallelEnv`` rollout
    (action excluded) does not require grad.

    Args:
        device: device on which the environments and the policy are created.
    """

    class DifferentiableEnv(EnvBase):
        """Toy env whose reward and next observation are differentiable in the action."""

        def __init__(self, device):
            super().__init__(device=device)
            self.observation_spec = CompositeSpec(
                observation=UnboundedContinuousTensorSpec(3, device=device),
                device=device,
            )
            self.action_spec = CompositeSpec(
                action=UnboundedContinuousTensorSpec(3, device=device),
                device=device,
            )
            self.reward_spec = CompositeSpec(
                reward=UnboundedContinuousTensorSpec(1, device=device),
                device=device,
            )
            self.seed = 0

        def _set_seed(self, seed):
            # The seed is only stored; ``_reset`` uses it to offset the
            # initial observation.
            self.seed = seed
            return seed

        def _reset(self, tensordict):
            td = self.observation_spec.zero().update(self.done_spec.zero())
            # Zero observation shifted by ``seed % 10`` and flagged as
            # requiring grad.
            td["observation"] = (
                td["observation"].clone() + self.seed % 10
            ).requires_grad_()
            return td

        def _step(self, tensordict):
            action = tensordict.get("action")
            obs = (tensordict.get("observation") + action) / action.norm()
            return TensorDict(
                {
                    "reward": action.sum().unsqueeze(0),
                    **self.full_done_spec.zero(),
                    "observation": obs,
                },
                # Explicit empty batch size: the env is not batched.
                batch_size=[],
            )

    torch.manual_seed(0)
    # The policy must live on the same device as the environments, otherwise
    # the rollout mixes devices whenever ``device`` is not the default one.
    policy = Actor(torch.nn.Linear(3, 3, device=device))
    env0 = DifferentiableEnv(device=device)
    seed = env0.set_seed(0)
    env1 = DifferentiableEnv(device=device)
    env1.set_seed(seed)
    r0 = env0.rollout(10, policy)
    r1 = env1.rollout(10, policy)
    r = torch.stack([r0, r1])
    g = torch.autograd.grad(r["next", "reward"].sum(), policy.parameters())

    def make_env(seed, device=device):
        env = DifferentiableEnv(device=device)
        env.set_seed(seed)
        return env

    serial_env = SerialEnv(
        2,
        [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
    )
    r_serial = serial_env.rollout(10, policy)

    g_serial = torch.autograd.grad(
        r_serial["next", "reward"].sum(), policy.parameters()
    )
    torch.testing.assert_close(g, g_serial)

    p_env = ParallelEnv(
        2,
        [functools.partial(make_env, seed=0), functools.partial(make_env, seed=seed)],
    )
    try:
        r_parallel = p_env.rollout(10, policy)
        assert not r_parallel.exclude("action").requires_grad
    finally:
        # Always close the parallel env, even if the rollout or the
        # assertion fails.
        p_env.close()



if __name__ == "__main__":
    # Run this file through pytest directly; command-line arguments that
    # argparse does not recognise are forwarded to pytest unchanged.
    parser = argparse.ArgumentParser()
    args, unknown = parser.parse_known_args()
    pytest_argv = [__file__, "--capture", "no", "--exitfirst", *unknown]
    pytest.main(pytest_argv)
11 changes: 11 additions & 0 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,12 @@ class ParallelEnv(BatchedEnvBase, metaclass=_PEnvMeta):
>>> # If no cuda device is available
>>> env = ParallelEnv(N, MyEnv(..., device="cpu"))
.. warning::
ParallelEnv disables gradients in all operations (:meth:`~.step`,
:meth:`~.reset` and :meth:`~.step_and_maybe_reset`) because gradients
cannot be passed through :class:`multiprocessing.Pipe` objects.
Only :class:`~torchrl.envs.SerialEnv` will support backpropagation.
"""

def _start_workers(self) -> None:
Expand Down Expand Up @@ -1143,6 +1149,7 @@ def load_state_dict(self, state_dict: OrderedDict) -> None:
event.wait()
event.clear()

@torch.no_grad()
@_check_start
def step_and_maybe_reset(
self, tensordict: TensorDictBase
Expand All @@ -1155,6 +1162,7 @@ def step_and_maybe_reset(
# and this transform overrides an observation key (eg, CatFrames)
# the shape, dtype or device may not necessarily match and writing
# the value in-place will fail.
assert not self.shared_tensordict_parent.requires_grad
self.shared_tensordict_parent.update_(
tensordict, keys_to_update=self._env_input_keys
)
Expand Down Expand Up @@ -1205,6 +1213,7 @@ def step_and_maybe_reset(
tensordict.set("next", next_td)
return tensordict, tensordict_

@torch.no_grad()
@_check_start
def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We must use the in_keys and nothing else for the following reasons:
Expand All @@ -1215,6 +1224,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# and this transform overrides an observation key (eg, CatFrames)
# the shape, dtype or device may not necessarily match and writing
# the value in-place will fail.
assert not self.shared_tensordict_parent.requires_grad
self.shared_tensordict_parent.update_(
tensordict, keys_to_update=list(self._env_input_keys)
)
Expand Down Expand Up @@ -1261,6 +1271,7 @@ def select_and_clone(name, tensor):
out = out.to(device, non_blocking=self.non_blocking)
return out

@torch.no_grad()
@_check_start
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
if tensordict is not None:
Expand Down

0 comments on commit 2cf514f

Please sign in to comment.