Commit f934657

albertbou92 authored and Vincent Moens committed
[BugFix] Allow expanding TensorDictPrimer transforms shape with parent batch size (#2521)
(cherry picked from commit 98b45a6)
1 parent b3712ea commit f934657
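In short: TensorDictPrimer specs created without a leading batch dimension (such as the recurrent-state primers built by get_primers_from_module) can now be appended to a batched, batch-locked environment, and the transform expands the spec shapes to the parent's batch size at reset time. A minimal sketch of the new behavior, adapted from the docstring example added in this commit (assumes the Gym/Gymnasium Pendulum-v1 environment is available):

from torchrl.envs import SerialEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import GRUModule
from torchrl.modules.utils import get_primers_from_module

# A batched, batch-locked parent env: batch_size == torch.Size([2]).
env = TransformedEnv(SerialEnv(2, lambda: GymEnv("Pendulum-v1")))
# GRU primers are built without a leading batch dimension.
model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
env.append_transform(get_primers_from_module(model))
# On reset, the primer specs are expanded in place to match the parent batch size.
td = env.reset()
print(td["recurrent_state"].shape)  # torch.Size([2, 1, 2])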

2 files changed (+75, -7 lines)

test/test_transforms.py

Lines changed: 28 additions & 0 deletions
@@ -156,6 +156,7 @@
 from torchrl.envs.transforms.vip import _VIPNet, VIPRewardTransform
 from torchrl.envs.utils import check_env_specs, step_mdp
 from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
+from torchrl.modules.utils import get_primers_from_module

 IS_WIN = platform == "win32"
 if IS_WIN:
@@ -6953,6 +6954,33 @@ def test_dict_default_value(self):
             rollout_td.get(("next", "mykey2")) == torch.tensor(1, dtype=torch.int64)
         ).all()

+    def test_spec_shape_inplace_correction(self):
+        hidden_size = input_size = num_layers = 2
+        model = GRUModule(
+            input_size, hidden_size, num_layers, in_key="observation", out_key="action"
+        )
+        env = TransformedEnv(
+            SerialEnv(2, lambda: GymEnv("Pendulum-v1")),
+        )
+        # These primers do not have the leading batch dimension,
+        # since the model is agnostic to the batch dimension that will be used.
+        primers = get_primers_from_module(model)
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [num_layers, hidden_size]
+            )
+        env.append_transform(primers)
+
+        # Reset should add the batch dimension to the primers,
+        # since the parent exists and is batch_locked.
+        td = env.reset()
+
+        for primer in primers.primers:
+            assert primers.primers.get(primer).shape == torch.Size(
+                [2, num_layers, hidden_size]
+            )
+            assert td.get(primer).shape == torch.Size([2, num_layers, hidden_size])
+

 class TestTimeMaxPool(TransformBase):
     @pytest.mark.parametrize("T", [2, 4])
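The new test exercises the expansion end-to-end: the primer specs start at [num_layers, hidden_size] and gain the leading batch dimension of 2 after env.reset(). Assuming a standard pytest setup, it can be run in isolation with: pytest test/test_transforms.py -k test_spec_shape_inplace_correction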

torchrl/envs/transforms/transforms.py

Lines changed: 47 additions & 7 deletions
@@ -4592,10 +4592,11 @@ class TensorDictPrimer(Transform):
     The corresponding value has to be a TensorSpec instance indicating
     what the value must be.

-    When used in a TransfomedEnv, the spec shapes must match the envs shape if
-    the parent env is batch-locked (:obj:`env.batch_locked=True`).
-    If the env is not batch-locked (e.g. model-based envs), it is assumed that the batch is
-    given by the input tensordict instead.
+    When used in a `TransformedEnv`, the spec shapes must match the environment's shape if
+    the parent environment is batch-locked (`env.batch_locked=True`). If the spec shapes and
+    parent shapes do not match, the spec shapes are modified in-place to match the leading
+    dimensions of the parent's batch size. This adjustment is made for cases where the parent
+    batch size dimension is not known during instantiation.

     Examples:
         >>> from torchrl.envs.libs.gym import GymEnv
@@ -4635,6 +4636,40 @@ class TensorDictPrimer(Transform):
         tensor([[1., 1., 1.],
                 [1., 1., 1.]])

+    Examples:
+        >>> from torchrl.envs.libs.gym import GymEnv
+        >>> from torchrl.envs import SerialEnv, TransformedEnv
+        >>> from torchrl.modules.utils import get_primers_from_module
+        >>> from torchrl.modules import GRUModule
+        >>> base_env = SerialEnv(2, lambda: GymEnv("Pendulum-v1"))
+        >>> env = TransformedEnv(base_env)
+        >>> model = GRUModule(input_size=2, hidden_size=2, in_key="observation", out_key="action")
+        >>> primers = get_primers_from_module(model)
+        >>> print(primers)  # Primer shapes are independent of the env batch size
+        TensorDictPrimer(primers=Composite(
+            recurrent_state: UnboundedContinuous(
+                shape=torch.Size([1, 2]),
+                space=ContinuousBox(
+                    low=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True),
+                    high=Tensor(shape=torch.Size([1, 2]), device=cpu, dtype=torch.float32, contiguous=True)),
+                device=cpu,
+                dtype=torch.float32,
+                domain=continuous),
+            device=None,
+            shape=torch.Size([])), default_value={'recurrent_state': 0.0}, random=None)
+        >>> env.append_transform(primers)
+        >>> print(env.reset())  # The primers are automatically expanded to match the env batch size
+        TensorDict(
+            fields={
+                done: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                observation: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False),
+                recurrent_state: Tensor(shape=torch.Size([2, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+                terminated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                truncated: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+            batch_size=torch.Size([2]),
+            device=None,
+            is_shared=False)
+
     .. note:: Some TorchRL modules rely on specific keys being present in the environment TensorDicts,
         like :class:`~torchrl.modules.models.LSTM` or :class:`~torchrl.modules.models.GRU`.
         To facilitate this process, the method :func:`~torchrl.modules.utils.get_primers_from_module`
@@ -4760,7 +4795,7 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
             # We try to set the primer shape to the observation spec shape
             self.primers.shape = observation_spec.shape
         except ValueError:
-            # If we fail, we expnad them to that shape
+            # If we fail, we expand them to that shape
             self.primers = self._expand_shape(self.primers)
         device = observation_spec.device
         observation_spec.update(self.primers.clone().to(device))
@@ -4827,12 +4862,17 @@ def _reset(
     ) -> TensorDictBase:
         """Sets the default values in the input tensordict.

-        If the parent is batch-locked, we assume that the specs have the appropriate leading
+        If the parent is batch-locked, we make sure the specs have the appropriate leading
         shape. We allow for execution when the parent is missing, in which case the
         spec shape is assumed to match the tensordict's.
-
         """
         _reset = _get_reset(self.reset_key, tensordict)
+        if (
+            self.parent
+            and self.parent.batch_locked
+            and self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
+        ):
+            self.primers = self._expand_shape(self.primers)
         if _reset.any():
             for key, spec in self.primers.items(True, True):
                 if self.random:
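The guard added to _reset compares only the leading dimensions of the primer spec shape with the parent's batch size, so primers that were already expanded pass through untouched. A small illustration of that prefix check with plain torch.Size values (a sketch only; the real code delegates the actual expansion to self._expand_shape):

import torch

parent_batch_size = torch.Size([2])  # e.g. SerialEnv(2, ...)
primer_shape = torch.Size([1, 2])    # spec created without a batch dimension

# Mirrors: self.primers.shape[: len(self.parent.shape)] != self.parent.batch_size
if primer_shape[: len(parent_batch_size)] != parent_batch_size:
    # Stand-in for _expand_shape: prepend the parent batch dimensions.
    primer_shape = torch.Size([*parent_batch_size, *primer_shape])

print(primer_shape)  # torch.Size([2, 1, 2])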
