diff --git a/test/test_transforms.py b/test/test_transforms.py
index 1c1903a3b1b..942814818e6 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -159,7 +159,6 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
-from torchrl.modules.utils import get_primers_from_module
 
 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -7164,33 +7163,6 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all
 
-    def test_spec_shape_inplace_correction(self):
-        hidden_size = input_size = num_layers = 2
-        model = GRUModule(
-            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
-        )
-        env = TransformedEnv(
-            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
-        )
-        # These primers do not have the leading batch dimension
-        # since model is agnostic to batch dimension that will be used.
-        primers = get_primers_from_module(model)
-        for primer in primers.primers:
-            assert primers.primers.get(primer).shape == torch.Size(
-                [num_layers, hidden_size]
-            )
-        env.append_transform(primers)
-
-        # Reset should add the batch dimension to the primers
-        # since the parent exists and is batch_locked.
-        td = env.reset()
-
-        for primer in primers.primers:
-            assert primers.primers.get(primer).shape == torch.Size(
-                [2, num_layers, hidden_size]
-            )
-            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
-
 
 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index 83233aaaac0..600e03775f7 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -4596,11 +4596,10 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.
 
-    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
-    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
-    parent shapes do not match, the spec shapes are modified in-place to match the leading
-    dimensions of the parent's batch size. This adjustment is made for cases where the parent
-    batch size dimension is not known during instantiation.
+    When used in a TransfomedEnv, the spec shapes must match the envs shape if
+    the parent env is batch-locked (:obj:`env.batch_locked=True`).
+    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
+    given by the input tensordict instead.
 
     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4640,40 +4639,6 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])
 
-    Examples:
-        >>> from torchrl.envs.libs.gym import GymEnv
-        >>> from torchrl.envs import SerialEnv, TransformedEnv
-        >>> from torchrl.modules.utils import get_primers_from_module
-        >>> from torchrl.modules import GRUModule
-        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
-        >>> env = TransformedEnv(base_env)
-        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
-        >>> primers = get_primers_from_module(model)
-        >>> print(primers) # Primers shape is independent of the env batch size
-        TensorDictPrimer(primers=Composite(
-            recurrent_state: UnboundedContinuous(
-                shape=torch.Size([1, 2]),
-                space=ContinuousBox(
-                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
-                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
-                device=cpu,
-                dtype=torch.float32,
-                domain=continuous),
-            device=None,
-            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
-        >>> env.append_transform(primers)
-        >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
-        TensorDict(
-            fields={
-                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
-                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
-                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
-                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
-                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
-            batch_size=torch.Size([2]),
-            device=None,
-            is_shared=False)
-
     .. note:: Some TorchRL modules rely on specific keys being present in the environment
         TensorDicts, like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4799,7 +4764,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
                 # We try to set the primer shape to the observation spec shape
                 self.primers.shape = observation_spec.shape
             except ValueError:
-                # If we fail, we expand them to that shape
+                # If we fail, we expnad them to that shape
                 self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
@@ -4866,17 +4831,12 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.
 
-        If the parent is batch-locked, we make sure the specs have the appropriate leading
+        If the parent is batch-locked, we assume that the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
+
         """
         _reset = _get_reset(self.reset_key, tensordict)
-        if (
-            self.parent
-            and self.parent.batch_locked
-            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
-        ):
-            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
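
For reference, a minimal sketch (not part of the patch) of the behaviour the restored docstring describes: with a batch-locked parent such as a SerialEnv, the primer spec is expected to already carry the env's leading batch dimension, since the reverted in-place shape correction no longer runs on reset. The key name "mykey" is illustrative, Unbounded is assumed to be importable from torchrl.data (older releases expose it as UnboundedContinuousTensorSpec), and Gym's "Pendulum-v1" is assumed to be installed.

    >>> from torchrl.data import Unbounded
    >>> from torchrl.envs import SerialEnv, TransformedEnv
    >>> from torchrl.envs.libs.gym import GymEnv
    >>> from torchrl.envs.transforms import TensorDictPrimer
    >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))  # batch-locked, batch_size [2]
    >>> # the spec shape carries the leading batch dimension of the parent env itself
    >>> primer = TensorDictPrimer(mykey=Unbounded([2, 3]))
    >>> env = TransformedEnv(base_env, primer)
    >>> td = env.reset()
    >>> print(td.get("mykey").shape)
    torch.Size([2, 3])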