@@ -4592,10 +4592,11 @@ class TensorDictPrimer(Transform):
4592
4592
The corresponding value has to be a TensorSpec instance indicating
4593
4593
what the value must be.
4594
4594
4595
- When used in a TransfomedEnv, the spec shapes must match the envs shape if
4596
- the parent env is batch-locked (:obj:`env.batch_locked=True`).
4597
- If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
4598
- given by the input tensordict instead.
4595
+ When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
4596
+ the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
4597
+ parent shapes do not match, the spec shapes are modified in-place to match the leading
4598
+ dimensions of the parent's batch size. This adjustment is made for cases where the parent
4599
+ batch size dimension is not known during instantiation.
4599
4600
4600
4601
Examples:
4601
4602
>>> from torchrl.envs.libs.gym import GymEnv
@@ -4635,6 +4636,40 @@ class TensorDictPrimer(Transform):
4635
4636
tensor([[1., 1., 1.],
4636
4637
[1., 1., 1.]])
4637
4638
4639
+ Examples:
4640
+ >>> from torchrl.envs.libs.gym import GymEnv
4641
+ >>> from torchrl.envs import SerialEnv, TransformedEnv
4642
+ >>> from torchrl.modules.utils import get_primers_from_module
4643
+ >>> from torchrl.modules import GRUModule
4644
+ >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
4645
+ >>> env = TransformedEnv(base_env)
4646
+ >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
4647
+ >>> primers = get_primers_from_module(model)
4648
+ >>> print(primers) # Primers shape is independent of the env batch size
4649
+ TensorDictPrimer(primers=Composite(
4650
+ recurrent_state: UnboundedContinuous(
4651
+ shape=torch.Size([1, 2]),
4652
+ space=ContinuousBox(
4653
+ low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
4654
+ high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
4655
+ device=cpu,
4656
+ dtype=torch.float32,
4657
+ domain=continuous),
4658
+ device=None,
4659
+ shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
4660
+ >>> env.append_transform(primers)
4661
+ >>> print(env.reset()) # The primers are automatically expanded to match the env batch size
4662
+ TensorDict(
4663
+ fields={
4664
+ done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4665
+ observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
4666
+ recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
4667
+ terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
4668
+ truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
4669
+ batch_size=torch.Size([2]),
4670
+ device=None,
4671
+ is_shared=False)
4672
+
4638
4673
.. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
4639
4674
like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
4640
4675
To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4760,7 +4795,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
4760
4795
# We try to set the primer shape to the observation spec shape
4761
4796
self.primers.shape = observation_spec.shape
4762
4797
except ValueError:
4763
- # If we fail, we expnad them to that shape
4798
+ # If we fail, we expand them to that shape
4764
4799
self.primers = self._expand_shape(self.primers)
4765
4800
device = observation_spec.device
4766
4801
observation_spec.update(self.primers.clone().to(device))
@@ -4827,12 +4862,17 @@ def _reset(
4827
4862
) -> TensorDictBase:
4828
4863
"""Sets the default values in the input tensordict.
4829
4864
4830
- If the parent is batch-locked, we assume that the specs have the appropriate leading
4865
+ If the parent is batch-locked, we make sure the specs have the appropriate leading
4831
4866
shape. We allow for execution when the parent is missing, in which case the
4832
4867
spec shape is assumed to match the tensordict's.
4833
-
4834
4868
"""
4835
4869
_reset = _get_reset(self.reset_key, tensordict)
4870
+ if (
4871
+ self.parent
4872
+ and self.parent.batch_locked
4873
+ and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
4874
+ ):
4875
+ self.primers = self._expand_shape(self.primers)
4836
4876
if _reset.any():
4837
4877
for key, spec in self.primers.items(True, True):
4838
4878
if self.random:
0 commit comments