
[Feature] ConditionalPolicySwitch transform #2711


Open · wants to merge 12 commits into base: gh/vmoens/76/base
1 change: 1 addition & 0 deletions docs/source/reference/envs.rst
@@ -1112,6 +1112,7 @@ to be able to create this other composition:
    CenterCrop
    ClipTransform
    Compose
    ConditionalPolicySwitch
    ConditionalSkip
    Crop
    DataLoadingPrimer
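For orientation, the snippet below sketches how the new transform is exercised by the tests in this PR. It is a minimal sketch, not the documented API: the import path is an assumption, and CountingEnv is a helper from the test suite's mocking classes, not part of the library.

# Sketch only: the import path and CountingEnv (a test helper) are assumptions.
from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

env = CountingEnv(max_steps=15)
# Hand control to ``policy_even`` whenever the step count is even.
condition = lambda td: ((td.get("step_count") % 2) == 0).all()
policy_even = lambda td: td.set("action", env.action_spec.one())
env = env.append_transform(
    Compose(
        StepCounter(),
        ConditionalPolicySwitch(condition=condition, policy=policy_even),
    )
)
# The policy handed to ``rollout`` then only ever plays the odd steps; the
# even steps are handled inside the transform.
policy_odd = lambda td: td.set("action", env.action_spec.zero())
rollout = env.rollout(100, policy_odd, break_when_any_done=False)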
204 changes: 203 additions & 1 deletion test/test_transforms.py
@@ -24,6 +24,7 @@

import tensordict.tensordict
import torch

from tensordict import (
    assert_close,
    LazyStackedTensorDict,
@@ -33,7 +34,7 @@
    TensorDictBase,
    unravel_key,
)
from tensordict.nn import TensorDictSequential
from tensordict.nn import TensorDictSequential, WrapModule
from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
from torch import multiprocessing as mp, nn, Tensor
from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -62,6 +63,7 @@
    CenterCrop,
    ClipTransform,
    Compose,
    ConditionalPolicySwitch,
    ConditionalSkip,
    Crop,
    DeviceCastTransform,
@@ -14526,6 +14528,206 @@ def test_can_init_with_fps(self):
        assert recorder is not None


class TestConditionalPolicySwitch(TransformBase):
    def test_single_trans_env_check(self):
        base_env = CountingEnv(max_steps=15)
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 0 plays action 0 on odd steps; the transform plays action 1 on
        # even steps. The lambdas close over ``env``, assigned below (late binding).
        policy_odd = lambda td: td.set("action", env.action_spec.zero())
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()

    def _create_policy_odd(self, base_env):
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.zero(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_policy_even(self, base_env):
        return WrapModule(
            lambda td, base_env=base_env: td.set(
                "action", base_env.action_spec_unbatched.one(td.shape)
            ),
            out_keys=["action"],
        )

    def _create_transforms(self, condition, policy_even):
        return Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )

    def _make_env(self, max_count, env_cls):
        torch.manual_seed(0)
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        base_env = env_cls(max_steps=max_count)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        return base_env.append_transform(transforms)

    def _test_env(self, env, policy_odd):
        env.check_env_specs()
        env.set_seed(0)
        r = env.rollout(100, policy_odd, break_when_any_done=False)
        # Check results are independent: one reset / step in one env should not impact results in another
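        # Each sub-env has a different max_steps (6, 7 and 8), so each trajectory
        # repeats with its own period; comparing consecutive chunks of a rollout
        # against the first chunk checks that every env keeps its own cycle.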
        r0, r1, r2 = r.unbind(0)
        r0_split = r0.split(6)
        assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
        r1_split = r1.split(7)
        assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
        r2_split = r2.split(8)
        assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])

    def test_trans_serial_env_check(self):
        torch.manual_seed(0)
        base_env = SerialEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
        )
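        # ``batch_locked=False`` (here and in the parallel variant below) appears
        # to be needed so the transform can step just the subset of sub-envs for
        # which the condition holds.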
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_trans_parallel_env_check(self):
        torch.manual_seed(0)
        base_env = ParallelEnv(
            3,
            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
            batch_locked=False,
            mp_start_method=mp_ctx,
        )
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(base_env)
        policy_even = self._create_policy_even(base_env)
        transforms = self._create_transforms(condition, policy_even)
        env = base_env.append_transform(transforms)
        self._test_env(env, policy_odd)

    def test_serial_trans_env_check(self):
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
        self._test_env(env, policy_odd)

    def test_parallel_trans_env_check(self):
        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
        policy_odd = self._create_policy_odd(CountingEnv())

        def make_env(max_count):
            return partial(self._make_env, max_count, CountingEnv)

        env = ParallelEnv(
            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
        )
        self._test_env(env, policy_odd)

    def test_transform_no_env(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
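        # The transform only acts from inside ``env.step`` / ``env.reset``;
        # calling it directly as a module is expected to raise.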
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_compose(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = Compose(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    def test_transform_env(self):
        base_env = CountingEnv(max_steps=15)
        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
        # Player 0 plays action 0 on odd steps; the transform plays action 1 on
        # even steps. The lambdas close over ``env``, assigned below (late binding).
        policy_odd = lambda td: td.set("action", env.action_spec.zero())
        policy_even = lambda td: td.set("action", env.action_spec.one())
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        env = base_env.append_transform(transforms)
        env.check_env_specs()
        r = env.rollout(1000, policy_odd, break_when_all_done=True)
        assert r.shape[0] == 15
        assert (r["action"] == 0).all()
        assert (
            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
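        # The stride of 2 in ``step_count`` reflects that the even steps were
        # consumed internally by the transform's policy: only the odd steps
        # show up in the collected rollout.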
        assert r["next", "done"].any()

        # Player 1: the transform now intercepts the odd steps and plays action 0;
        # the rollout policy plays action 1 on the even steps.
        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
        transforms = Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
        )
        env = base_env.append_transform(transforms)
        r = env.rollout(1000, policy_even, break_when_all_done=True)
        assert r.shape[0] == 16
        assert (r["action"] == 1).all()
        assert (
            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
        ).all()
        assert r["next", "done"].any()

    def test_transform_model(self):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        transforms = nn.Sequential(
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            transforms(TensorDict())

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    def test_transform_rb(self, rbclass):
        policy_odd = lambda td: td
        policy_even = lambda td: td
        condition = lambda td: True
        rb = rbclass(storage=LazyTensorStorage(10))
        rb.append_transform(
            ConditionalPolicySwitch(condition=condition, policy=policy_even)
        )
        rb.extend(TensorDict(batch_size=[2]))
        with pytest.raises(
            RuntimeError,
            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
        ):
            rb.sample(2)

    def test_transform_inverse(self):
        # ConditionalPolicySwitch has no inverse; nothing to test.
        return


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
8 changes: 5 additions & 3 deletions torchrl/envs/batched_envs.py
@@ -350,6 +350,7 @@ def __init__(

        # if share_individual_td is None, we will assess later if the output can be stacked
        self.share_individual_td = share_individual_td
        # self._batch_locked = batch_locked
        self._share_memory = shared_memory
        self._memmap = memmap
        self.allow_step_when_done = allow_step_when_done
@@ -626,8 +627,8 @@ def map_device(key, value, device_map=device_map):
                self._env_tensordict.named_apply(
                    map_device, nested_keys=True, filter_empty=True
                )

            self._batch_locked = meta_data.batch_locked
            # if self._batch_locked is None:
            #     self._batch_locked = meta_data.batch_locked
        else:
            self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
            devices = set()
@@ -668,7 +669,8 @@ def map_device(key, value, device_map=device_map):
            self._env_tensordict = torch.stack(
                [meta_data.tensordict for meta_data in meta_data], 0
            )
            self._batch_locked = meta_data[0].batch_locked
            # if self._batch_locked is None:
            #     self._batch_locked = meta_data[0].batch_locked
        self.has_lazy_inputs = contains_lazy_spec(self.input_spec)

    def state_dict(self) -> OrderedDict:
4 changes: 1 addition & 3 deletions torchrl/envs/custom/chess.py
@@ -195,9 +195,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
        batch_size=torch.Size([96]),
        device=None,
        is_shared=False)


    """
    """ # noqa: D301

    _hash_table: dict[int, str] = {}
    _PGN_RESTART = """[Event "?"]