diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst
index ec8b29a9abd..7aa6fd7a42c 100644
--- a/docs/source/reference/envs.rst
+++ b/docs/source/reference/envs.rst
@@ -1112,6 +1112,7 @@ to be able to create this other composition:
     CenterCrop
     ClipTransform
     Compose
+    ConditionalPolicySwitch
     ConditionalSkip
     Crop
     DataLoadingPrimer
diff --git a/test/test_transforms.py b/test/test_transforms.py
index ef947a8a704..7f38f6cd58b 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -24,6 +24,7 @@
 import tensordict.tensordict
 import torch
+
 from tensordict import (
     assert_close,
     LazyStackedTensorDict,
@@ -33,7 +34,7 @@
     TensorDictBase,
     unravel_key,
 )
-from tensordict.nn import TensorDictSequential
+from tensordict.nn import TensorDictSequential, WrapModule
 from tensordict.utils import _unravel_key_to_tuple, assert_allclose_td
 from torch import multiprocessing as mp, nn, Tensor
 from torchrl._utils import _replace_last, prod, set_auto_unwrap_transformed_env
@@ -62,6 +63,7 @@
     CenterCrop,
     ClipTransform,
     Compose,
+    ConditionalPolicySwitch,
     ConditionalSkip,
     Crop,
     DeviceCastTransform,
@@ -14526,6 +14528,206 @@ def test_can_init_with_fps(self):
         assert recorder is not None
 
+
+class TestConditionalPolicySwitch(TransformBase):
+    def test_single_trans_env_check(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+
+    def _create_policy_odd(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.zero(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_policy_even(self, base_env):
+        return WrapModule(
+            lambda td, base_env=base_env: td.set(
+                "action", base_env.action_spec_unbatched.one(td.shape)
+            ),
+            out_keys=["action"],
+        )
+
+    def _create_transforms(self, condition, policy_even):
+        return Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+
+    def _make_env(self, max_count, env_cls):
+        torch.manual_seed(0)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        base_env = env_cls(max_steps=max_count)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        return base_env.append_transform(transforms)
+
+    def _test_env(self, env, policy_odd):
+        env.check_env_specs()
+        env.set_seed(0)
+        r = env.rollout(100, policy_odd, break_when_any_done=False)
+        # Check results are independent: one reset / step in one env should not impact results in another
+        r0, r1, r2 = r.unbind(0)
+        r0_split = r0.split(6)
+        assert all((r == r0_split[0][: r.numel()]).all() for r in r0_split[1:])
+        r1_split = r1.split(7)
+        assert all((r == r1_split[0][: r.numel()]).all() for r in r1_split[1:])
+        r2_split = r2.split(8)
+        assert all((r == r2_split[0][: r.numel()]).all() for r in r2_split[1:])
+
+    def test_trans_serial_env_check(self):
+        torch.manual_seed(0)
+        base_env = SerialEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_trans_parallel_env_check(self):
+        torch.manual_seed(0)
+        base_env = ParallelEnv(
+            3,
+            [partial(CountingEnv, 6), partial(CountingEnv, 7), partial(CountingEnv, 8)],
+            batch_locked=False,
+            mp_start_method=mp_ctx,
+        )
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(base_env)
+        policy_even = self._create_policy_even(base_env)
+        transforms = self._create_transforms(condition, policy_even)
+        env = base_env.append_transform(transforms)
+        self._test_env(env, policy_odd)
+
+    def test_serial_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = SerialEnv(3, [make_env(6), make_env(7), make_env(8)])
+        self._test_env(env, policy_odd)
+
+    def test_parallel_trans_env_check(self):
+        condition = lambda td: ((td.get("step_count") % 2) == 0).squeeze(-1)
+        policy_odd = self._create_policy_odd(CountingEnv())
+
+        def make_env(max_count):
+            return partial(self._make_env, max_count, CountingEnv)
+
+        env = ParallelEnv(
+            3, [make_env(6), make_env(7), make_env(8)], mp_start_method=mp_ctx
+        )
+        self._test_env(env, policy_odd)
+
+    def test_transform_no_env(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_compose(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = Compose(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    def test_transform_env(self):
+        base_env = CountingEnv(max_steps=15)
+        condition = lambda td: ((td.get("step_count") % 2) == 0).all()
+        # Player 0
+        policy_odd = lambda td: td.set("action", env.action_spec.zero())
+        policy_even = lambda td: td.set("action", env.action_spec.one())
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        env = base_env.append_transform(transforms)
+        env.check_env_specs()
+        r = env.rollout(1000, policy_odd, break_when_all_done=True)
+        assert r.shape[0] == 15
+        assert (r["action"] == 0).all()
+        assert (
+            r["step_count"] == torch.arange(1, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+        # Player 1
+        condition = lambda td: ((td.get("step_count") % 2) == 1).all()
+        transforms = Compose(
+            StepCounter(),
+            ConditionalPolicySwitch(condition=condition, policy=policy_odd),
+        )
+        env = base_env.append_transform(transforms)
+        r = env.rollout(1000, policy_even, break_when_all_done=True)
+        assert r.shape[0] == 16
+        assert (r["action"] == 1).all()
+        assert (
+            r["step_count"] == torch.arange(0, r.numel() * 2, 2).unsqueeze(-1)
+        ).all()
+        assert r["next", "done"].any()
+
+    def test_transform_model(self):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        transforms = nn.Sequential(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even),
+        )
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            transforms(TensorDict())
+
+    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
+    def test_transform_rb(self, rbclass):
+        policy_odd = lambda td: td
+        policy_even = lambda td: td
+        condition = lambda td: True
+        rb = rbclass(storage=LazyTensorStorage(10))
+        rb.append_transform(
+            ConditionalPolicySwitch(condition=condition, policy=policy_even)
+        )
+        rb.extend(TensorDict(batch_size=[2]))
+        with pytest.raises(
+            RuntimeError,
+            match="ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional.",
+        ):
+            rb.sample(2)
+
+    def test_transform_inverse(self):
+        return
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 39804ac4352..46c0c78155f 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -350,6 +350,7 @@ def __init__(
         # if share_individual_td is None, we will assess later if the output can be stacked
         self.share_individual_td = share_individual_td
 
+        # self._batch_locked = batch_locked
         self._share_memory = shared_memory
         self._memmap = memmap
         self.allow_step_when_done = allow_step_when_done
@@ -626,8 +627,8 @@ def map_device(key, value, device_map=device_map):
             self._env_tensordict.named_apply(
                 map_device, nested_keys=True, filter_empty=True
             )
-
-            self._batch_locked = meta_data.batch_locked
+            # if self._batch_locked is None:
+            #     self._batch_locked = meta_data.batch_locked
         else:
             self._batch_size = torch.Size([self.num_workers, *meta_data[0].batch_size])
             devices = set()
@@ -668,7 +669,8 @@ def map_device(key, value, device_map=device_map):
             self._env_tensordict = torch.stack(
                 [meta_data.tensordict for meta_data in meta_data], 0
             )
-            self._batch_locked = meta_data[0].batch_locked
+            # if self._batch_locked is None:
+            #     self._batch_locked = meta_data[0].batch_locked
         self.has_lazy_inputs = contains_lazy_spec(self.input_spec)
 
     def state_dict(self) -> OrderedDict:
diff --git a/torchrl/envs/custom/chess.py b/torchrl/envs/custom/chess.py
index 742519f2fec..23ccbc73b0a 100644
--- a/torchrl/envs/custom/chess.py
+++ b/torchrl/envs/custom/chess.py
@@ -195,9 +195,7 @@ class ChessEnv(EnvBase, metaclass=_ChessMeta):
         batch_size=torch.Size([96]),
         device=None,
         is_shared=False)
-
-
-    """
+    """  # noqa: D301
 
     _hash_table: dict[int, str] = {}
     _PGN_RESTART = """[Event "?"]
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 19e2ad7ec7d..1e437462881 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -85,6 +85,7 @@
 )
 from torchrl.envs.utils import (
     _sort_keys,
+    _terminated_or_truncated,
     _update_during_reset,
     make_composite_from_td,
     step_mdp,
@@ -11186,3 +11187,243 @@ def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         raise NotImplementedError(FORWARD_NOT_IMPLEMENTED)
+
+
+class ConditionalPolicySwitch(Transform):
+    """A transform that conditionally switches between policies based on a specified condition.
+
+    This transform evaluates a condition on the data returned by the environment's `step` method.
+    If the condition is met, it applies a specified policy to the data. Otherwise, the data is
+    returned unaltered. This is useful for scenarios where different policies need to be applied
+    based on certain criteria, such as alternating turns in a game.
+
+    Args:
+        policy (Callable[[TensorDictBase], TensorDictBase]):
+            The policy to be applied when the condition is met. This should be a callable that
+            takes a `TensorDictBase` and returns a `TensorDictBase`.
+        condition (Callable[[TensorDictBase], bool]):
+            A callable that takes a `TensorDictBase` and returns a boolean or a tensor indicating
+            whether the policy should be applied.
+
+    .. warning:: This transform must have a parent environment.
+
+    .. note:: Ideally, it should be the last transform in the stack. If the policy requires transformed
+        data (e.g., images), and this transform is applied before those transformations, the policy will
+        not receive the transformed data.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict.nn import TensorDictModule as Mod
+        >>>
+        >>> from torchrl.envs import GymEnv, ConditionalPolicySwitch, Compose, StepCounter
+        >>> # Create a CartPole environment. We'll be looking at the obs: if its first element (the cart
+        >>> # position) is greater than or equal to 0 (cart to the right of center), the switch policy pushes
+        >>> # left (action=0). Otherwise, we use our main policy, which pushes right (action=1).
+        >>> base_env = GymEnv("CartPole-v1", categorical_action_encoding=True)
+        >>>
+        >>> policy = Mod(lambda: torch.ones((), dtype=torch.int64), in_keys=[], out_keys=["action"])
+        >>> policy_switch = Mod(lambda: torch.zeros((), dtype=torch.int64), in_keys=[], out_keys=["action"])
+        >>>
+        >>> cond = lambda td: td.get("observation")[..., 0] >= 0
+        >>>
+        >>> env = base_env.append_transform(
+        ...     Compose(
+        ...         # We use two step counters to show that one counts the global steps, whereas the other
+        ...         # only counts the steps where the main policy is executed
+        ...         StepCounter(step_count_key="step_count_total"),
+        ...         ConditionalPolicySwitch(condition=cond, policy=policy_switch),
+        ...         StepCounter(step_count_key="step_count_main"),
+        ...     )
+        ... )
+        >>>
+        >>> env.set_seed(0)
+        >>> torch.manual_seed(0)
+        >>>
+        >>> r = env.rollout(100, policy=policy)
+        >>> print("action", r["action"])
+        action tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
+        >>> print("obs", r["observation"])
+        obs tensor([[ 0.0322, -0.1540,  0.0111,  0.3190],
+                [ 0.0299, -0.1544,  0.0181,  0.3280],
+                [ 0.0276, -0.1550,  0.0255,  0.3414],
+                [ 0.0253, -0.1558,  0.0334,  0.3596],
+                [ 0.0230, -0.1569,  0.0422,  0.3828],
+                [ 0.0206, -0.1582,  0.0519,  0.4117],
+                [ 0.0181, -0.1598,  0.0629,  0.4469],
+                [ 0.0156, -0.1617,  0.0753,  0.4891],
+                [ 0.0130, -0.1639,  0.0895,  0.5394],
+                [ 0.0104, -0.1665,  0.1058,  0.5987],
+                [ 0.0076, -0.1696,  0.1246,  0.6685],
+                [ 0.0047, -0.1732,  0.1463,  0.7504],
+                [ 0.0016, -0.1774,  0.1715,  0.8459],
+                [-0.0020,  0.0150,  0.1884,  0.6117],
+                [-0.0017,  0.2071,  0.2006,  0.3838]])
+        >>> print("obs'", r["next", "observation"])
+        obs' tensor([[ 0.0299, -0.1544,  0.0181,  0.3280],
+                [ 0.0276, -0.1550,  0.0255,  0.3414],
+                [ 0.0253, -0.1558,  0.0334,  0.3596],
+                [ 0.0230, -0.1569,  0.0422,  0.3828],
+                [ 0.0206, -0.1582,  0.0519,  0.4117],
+                [ 0.0181, -0.1598,  0.0629,  0.4469],
+                [ 0.0156, -0.1617,  0.0753,  0.4891],
+                [ 0.0130, -0.1639,  0.0895,  0.5394],
+                [ 0.0104, -0.1665,  0.1058,  0.5987],
+                [ 0.0076, -0.1696,  0.1246,  0.6685],
+                [ 0.0047, -0.1732,  0.1463,  0.7504],
+                [ 0.0016, -0.1774,  0.1715,  0.8459],
+                [-0.0020,  0.0150,  0.1884,  0.6117],
+                [-0.0017,  0.2071,  0.2006,  0.3838],
+                [ 0.0105,  0.2015,  0.2115,  0.5110]])
+        >>> print("total step count", r["step_count_total"].squeeze())
+        total step count tensor([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19, 21, 23, 25, 26, 27])
+        >>> print("total step with main policy", r["step_count_main"].squeeze())
+        total step with main policy tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14])
+
+    """
+
+    def __init__(
+        self,
+        policy: Callable[[TensorDictBase], TensorDictBase],
+        condition: Callable[[TensorDictBase], bool],
+    ):
+        super().__init__([], [])
+        self.__dict__["policy"] = policy
+        self.condition = condition
+
+    def _step(
+        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
+    ) -> TensorDictBase:
+        cond = self.condition(next_tensordict)
+        if not isinstance(cond, (bool, torch.Tensor)):
+            raise RuntimeError(
+                "Calling the condition function should return a boolean or a tensor."
+            )
+        elif isinstance(cond, (torch.Tensor,)):
+            if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)):
+                raise RuntimeError(
+                    "Tensor outputs must have the shape of the tensordict, or contain a single element."
+                )
+        else:
+            cond = torch.tensor(cond, device=tensordict.device)
+
+        if cond.any():
+            step = tensordict.get("_step", cond)
+            if step.shape != cond.shape:
+                step = step.view_as(cond)
+            cond = cond & step
+
+            parent: TransformedEnv = self.parent
+            any_done, done = self._check_done(next_tensordict)
+            next_td_save = None
+            if any_done:
+                if next_tensordict.numel() == 1 or done.all():
+                    return next_tensordict
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batched locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
+                done = done.view(next_tensordict.shape)
+                cond = cond & ~done
+            if not cond.all():
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batched locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
+                next_td_save = next_tensordict
+                next_tensordict = next_tensordict[cond]
+                tensordict = tensordict[cond]
+
+            # The policy may be expensive or may raise an exception when executed with inadequate data,
+            # so we index the td first
+            td = self.policy(
+                parent.step_mdp(tensordict.copy().set("next", next_tensordict))
+            )
+            # Mark the partial steps if needed
+            if next_td_save is not None:
+                td_new = td.new_zeros(cond.shape)
+                # TODO: swap with masked_scatter when avail
+                td_new[cond] = td
+                td = td_new
+                td.set("_step", cond)
+            next_tensordict = parent._step(td)
+            if next_td_save is not None:
+                return torch.where(cond, next_tensordict, next_td_save)
+            return next_tensordict
+        return next_tensordict
+
+    def _check_done(self, tensordict):
+        env = self.parent
+        if env._simple_done:
+            done = tensordict._get_str("done", default=None)
+            if done is not None:
+                any_done = done.any()
+            else:
+                any_done = False
+        else:
+            any_done = _terminated_or_truncated(
+                tensordict,
+                full_done_spec=env.output_spec["full_done_spec"],
+                key="_reset",
+            )
+            done = tensordict.pop("_reset")
+        return any_done, done
+
+    def _reset(
+        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
+    ) -> TensorDictBase:
+        cond = self.condition(tensordict_reset)
+        # TODO: move to validate
+        if not isinstance(cond, (bool, torch.Tensor)):
+            raise RuntimeError(
+                "Calling the condition function should return a boolean or a tensor."
+            )
+        elif isinstance(cond, (torch.Tensor,)):
+            if tuple(cond.shape) not in ((1,), (), tuple(tensordict.shape)):
+                raise RuntimeError(
+                    "Tensor outputs must have the shape of the tensordict, or contain a single element."
+                )
+        else:
+            cond = torch.tensor(cond, device=tensordict.device)
+
+        if cond.any():
+            reset = tensordict.get("_reset", cond)
+            if reset.shape != cond.shape:
+                reset = reset.view_as(cond)
+            cond = cond & reset
+
+            parent: TransformedEnv = self.parent
+            reset_td_save = None
+            if not cond.all():
+                if parent.base_env.batch_locked:
+                    raise RuntimeError(
+                        "Cannot run partial steps in a batched locked environment. "
+                        "Hint: Parallel and Serial envs can be unlocked through a keyword argument in "
+                        "the constructor."
+                    )
+                reset_td_save = tensordict_reset.copy()
+                tensordict_reset = tensordict_reset[cond]
+                tensordict = tensordict[cond]
+
+            td = self.policy(tensordict_reset)
+            # Mark the partial steps if needed
+            if reset_td_save is not None:
+                td_new = td.new_zeros(cond.shape)
+                # TODO: swap with masked_scatter when avail
+                td_new[cond] = td
+                td = td_new
+                td.set("_step", cond)
+            tensordict_reset = parent._step(td).exclude(*parent.reward_keys)
+            if reset_td_save is not None:
+                return torch.where(cond, tensordict_reset, reset_td_save)
+            return tensordict_reset
+
+        return tensordict_reset
+
+    def forward(self, tensordict: TensorDictBase) -> Any:
+        raise RuntimeError(
+            "ConditionalPolicySwitch cannot be called independently, only its step and reset methods are functional."
+        )
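A minimal usage sketch, distilled from the tests above, of how the new transform is meant to be driven: the rollout policy plays one side while ConditionalPolicySwitch injects the other side's move whenever the step count is even. CountingEnv, StepCounter, WrapModule and the action_spec.zero()/one() helpers are the objects the new tests rely on; the even/odd split and the import of CountingEnv from the test helpers are illustrative assumptions, not part of the transform's API.

    import torch
    from tensordict.nn import WrapModule

    from torchrl.envs import Compose, ConditionalPolicySwitch, StepCounter

    # CountingEnv is the toy environment used by the tests above; it lives in the
    # test helpers, so this import assumes the sketch is run from the test folder.
    from mocking_classes import CountingEnv

    torch.manual_seed(0)
    base_env = CountingEnv(max_steps=15)

    # Fire the switch whenever the step count written by StepCounter is even.
    condition = lambda td: ((td.get("step_count") % 2) == 0).all()

    # "Even" player: always plays action 1. WrapModule exposes the lambda with an
    # explicit out_key, mirroring the helper policies in TestConditionalPolicySwitch.
    policy_even = WrapModule(
        lambda td: td.set("action", base_env.action_spec.one()),
        out_keys=["action"],
    )

    env = base_env.append_transform(
        Compose(
            StepCounter(),
            ConditionalPolicySwitch(condition=condition, policy=policy_even),
        )
    )

    # "Odd" player drives the rollout; the transform interleaves the even player's
    # moves inside each env.step, so only the odd player's turns are recorded.
    policy_odd = lambda td: td.set("action", env.action_spec.zero())
    rollout = env.rollout(100, policy_odd, break_when_any_done=True)

Under these assumptions the recorded actions are all zeros and the step counter advances by two per recorded step, which is the behaviour the parity assertions in TestConditionalPolicySwitch check.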