diff --git a/test/test_transforms.py b/test/test_transforms.py
index 84c4b3871fa..b4627fdfebb 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -142,6 +142,7 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules.utils import get_primers_from_module

 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -6939,6 +6940,33 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all

+    def test_spec_shape_inplace_correction(self):
+        hidden_size = input_size = num_layers = 2
+        model = GRUModule(
+            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
+        )
+        env = TransformedEnv(
+            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
+        )
+        # These primers do not have the leading batch dimension,
+        # since the model is agnostic to the batch dimension that will be used.
+        primers = get_primers_from_module(model)
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [num_layers, hidden_size]
+            )
+        env.append_transform(primers)
+
+        # Reset should add the batch dimension to the primers,
+        # since the parent exists and is batch_locked.
+        td = env.reset()
+
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [2, num_layers, hidden_size]
+            )
+            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
+

 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index b70e05ca431..97ffb3650f6 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4592,10 +4592,11 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.

-    When used in a TransfomedEnv, the spec shapes must match the envs shape if
-    the parent env is batch-locked (:obj:`env.batch_locked=True`).
-    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
-    given by the input tensordict instead.
+    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
+    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
+    the parent's batch size do not match, the spec shapes are expanded in-place so that their
+    leading dimensions match the parent's batch size. This adjustment is made for cases where
+    the parent batch size is not known during instantiation.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4635,6 +4636,40 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])

+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.envs import SerialEnv, TransformedEnv
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule
+        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+        >>> env = TransformedEnv(base_env)
+        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers) # Primers shape is independent of the env batch size
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 2]),
+                space=ContinuousBox(
+                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
+                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous),
+            device=None,
+            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+        >>> env.append_transform(primers)
+        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
     .. note:: Some TorchRL modules rely on specific keys being present in the environment
         TensorDicts, like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4760,7 +4795,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             # We try to set the primer shape to the observation spec shape
             self.primers.shape = observation_spec.shape
         except ValueError:
-            # If we fail, we expnad them to that shape
+            # If we fail, we expand them to that shape
             self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
@@ -4827,12 +4862,17 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.

-        If the parent is batch-locked, we assume that the specs have the appropriate leading
+        If the parent is batch-locked, we make sure the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the spec shape
         is assumed to match the tensordict's.
-
         """
         _reset = _get_reset(self.reset_key, tensordict)
+        if (
+            self.parent
+            and self.parent.batch_locked
+            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
+        ):
+            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
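
Note for reviewers: a minimal sketch of the behaviour this diff introduces, written as a standalone script rather than a doctest. It reuses the same setup as the new test and the added docstring example (GRUModule, get_primers_from_module, SerialEnv over Pendulum-v1); the "recurrent_state" key and the expected shapes are taken from the docstring output above, not invented here.

# Sketch of the in-place primer-spec expansion added in this diff.
import torch
from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule
from torchrl.modules.utils import get_primers_from_module

env = TransformedEnv(SerialEnv(2, lambda: GymEnv("Pendulum-v1")))
model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")

primers = get_primers_from_module(model)
# Before the transform is attached and reset, the primer spec carries no batch dimension.
assert primers.primers["recurrent_state"].shape == torch.Size([1, 2])

env.append_transform(primers)
td = env.reset()
# Reset against the batch-locked parent expands the spec, and the reset output,
# to include the parent's batch size of 2 as the leading dimension.
assert primers.primers["recurrent_state"].shape == torch.Size([2, 1, 2])
assert td["recurrent_state"].shape == torch.Size([2, 1, 2])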