[BugFix] Compatibility of tensordict primers with batched envs (specifically for LSTM and GRU)

ghstack-source-id: e1da58ecfd36ca01b8a11fe90e5f3c5fe77f064c
Pull Request resolved: #2668
vmoens committed Dec 20, 2024
1 parent 133d709 commit f4709c1
Showing 6 changed files with 191 additions and 48 deletions.
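
For context, the pattern this commit targets is a recurrent module's tensordict primer appended to a batched env rather than to each worker env. Below is a minimal sketch of that usage, assuming TorchRL's LSTMModule/ParallelEnv APIs as exercised in the tests in this commit; the CartPole env, key names and sizes are illustrative, not taken from the diff:

import torch
from torchrl.envs import GymEnv, InitTracker, ParallelEnv, TransformedEnv
from torchrl.modules import LSTMModule

torch.manual_seed(0)

lstm = LSTMModule(
    input_size=4,   # CartPole observation size (illustrative)
    hidden_size=8,
    in_keys=["observation", "rs_h", "rs_c"],
    out_keys=["intermediate", ("next", "rs_h"), ("next", "rs_c")],
)

# The primer is appended once to the *batched* env; before this fix its spec
# shape could clash with the batched env's batch size.
env = TransformedEnv(ParallelEnv(3, lambda: GymEnv("CartPole-v1")))
env.append_transform(InitTracker())
env.append_transform(lstm.make_tensordict_primer())

td = env.reset()
print(td["rs_c"].shape)   # leading dim is the 3-worker batch
env.rollout(5)            # random-policy rollout; the primer keeps its entries present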
2 changes: 1 addition & 1 deletion sota-implementations/decision_transformer/utils.py
@@ -109,7 +109,7 @@ def make_transformed_env(base_env, env_cfg, obs_loc, obs_std, train=False):
)

# copy action from the input tensordict to the output
transformed_env.append_transform(TensorDictPrimer(action=base_env.action_spec))
transformed_env.append_transform(TensorDictPrimer(base_env.full_action_spec))

transformed_env.append_transform(DoubleToFloat())
obsnorm = ObservationNorm(
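For reference, TensorDictPrimer accepts either keyword specs or a single Composite. Passing the env's full_action_spec (a Composite keyed by the env's action key or keys) avoids hard-coding the "action" key name. A rough equivalent of the two calls, with an illustrative 6-dimensional action spec built by hand:

from torchrl.data import Composite, Unbounded
from torchrl.envs import TensorDictPrimer

# Old form: a single keyword entry named "action"
primer_kw = TensorDictPrimer(action=Unbounded(shape=(6,)))

# New form: hand over the env's full action spec as a Composite
full_action_spec = Composite(action=Unbounded(shape=(6,)))
primer_composite = TensorDictPrimer(full_action_spec)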
105 changes: 80 additions & 25 deletions test/test_tensordictmodules.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.

import argparse
import functools
import os

import pytest
@@ -12,6 +13,7 @@
import torchrl.modules
from tensordict import LazyStackedTensorDict, pad, TensorDict, unravel_key_list
from tensordict.nn import InteractionType, TensorDictModule, TensorDictSequential
from tensordict.utils import assert_close
from torch import nn
from torchrl.data.tensor_specs import Bounded, Composite, Unbounded
from torchrl.envs import (
@@ -938,10 +940,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_envs = 3
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
lstm_module = LSTMModule(
@@ -958,25 +962,36 @@ def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_envs)
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_envs,
)
if not within:
env = env.append_transform(InitTracker())
env.append_transform(lstm_module.make_tensordict_primer())

mlp = TensorDictModule(
MLP(
@@ -1002,6 +1017,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get(("next", "recurrent_state_c")) != 0.0).all()
assert (data.get("recurrent_state_c") != 0.0).any()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_lstm_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
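
The `within` flag added above distinguishes where the InitTracker/primer pair lives: inside every worker env (within=True) or appended once on top of the batched env (within=False), the case this commit fixes; test_lstm_parallel_within checks that both placements produce matching rollouts. Schematically, reusing the lstm_module from the test and a hypothetical make_base_env factory:

# within=True: every worker env carries its own InitTracker + primer
def make_env_with_primer():
    env = TransformedEnv(make_base_env())
    env.append_transform(InitTracker())
    env.append_transform(lstm_module.make_tensordict_primer())
    return env

env_within = ParallelEnv(3, make_env_with_primer)

# within=False: plain worker envs; the primer is added once around the batched env
env_outside = TransformedEnv(ParallelEnv(3, make_base_env))
env_outside.append_transform(InitTracker())
env_outside.append_transform(lstm_module.make_tensordict_primer())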
@@ -1330,10 +1358,12 @@ def test_multi_consecutive(self, shape, python_based):
@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
@pytest.mark.parametrize("within", [False, True])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous, within):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
num_workers = 3

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with parallel envs
@@ -1347,30 +1377,42 @@ def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
python_based=python_based,
)

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
if within:

def create_transformed_env():
primer = gru_module.make_tensordict_primer()
env = DiscreteActionVecMockEnv(
categorical_action_encoding=True, device=device
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

else:
create_transformed_env = functools.partial(
DiscreteActionVecMockEnv,
categorical_action_encoding=True,
device=device,
)
env = TransformedEnv(env)
env.append_transform(InitTracker())
env.append_transform(primer)
return env

if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env) for _ in range(num_workers)
]

env = cls(
env: ParallelEnv | SerialEnv = cls(
create_env_fn=create_transformed_env,
num_workers=2,
num_workers=num_workers,
)
if not within:
primer = gru_module.make_tensordict_primer()
env = env.append_transform(InitTracker())
env.append_transform(primer)

mlp = TensorDictModule(
MLP(
@@ -1396,6 +1438,19 @@ def create_transformed_env():
data = env.rollout(10, actor, break_when_any_done=break_when_any_done)
assert (data.get("recurrent_state") != 0.0).any()
assert (data.get(("next", "recurrent_state")) != 0.0).all()
return data

@pytest.mark.parametrize("python_based", [True, False])
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_within(self, python_based, parallel, heterogeneous):
out_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=True
)
out_not_within = self.test_gru_parallel_env(
python_based, parallel, heterogeneous, within=False
)
assert_close(out_within, out_not_within)

@pytest.mark.skipif(
not _has_functorch, reason="vmap can only be used with functorch"
45 changes: 36 additions & 9 deletions test/test_transforms.py
@@ -7408,7 +7408,7 @@ def make_env():
def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
env = TransformedEnv(
maybe_fork_ParallelEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded([2, 4])),
TensorDictPrimer(mykey=Unbounded([4])),
)
try:
check_env_specs(env)
@@ -7423,11 +7423,39 @@ def test_trans_parallel_env_check(self, maybe_fork_ParallelEnv):
pass

@pytest.mark.parametrize("spec_shape", [[4], [2, 4]])
def test_trans_serial_env_check(self, spec_shape):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(mykey=Unbounded(spec_shape)),
)
@pytest.mark.parametrize("expand_specs", [True, False, None])
def test_trans_serial_env_check(self, spec_shape, expand_specs):
if expand_specs is None:
with pytest.warns(FutureWarning, match=""):
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
env.observation_spec
elif expand_specs is True:
shape = spec_shape[:-1]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
Composite(mykey=Unbounded(spec_shape), shape=shape),
expand_specs=expand_specs,
),
)
else:
# If we don't expand, we can't use [4]
env = TransformedEnv(
SerialEnv(2, ContinuousActionVecMockEnv),
TensorDictPrimer(
mykey=Unbounded(spec_shape), expand_specs=expand_specs
),
)
if spec_shape == [4]:
with pytest.raises(ValueError):
env.observation_spec
return

check_env_specs(env)
assert "mykey" in env.reset().keys()
r = env.rollout(3)
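
The new expand_specs argument exercised above controls what happens when the primer's spec shape does not cover the batched env's batch size: True expands the spec to the batch, False requires the shapes to match already (so the [4] case raises), and None keeps the previous try-set-then-expand behaviour while emitting a FutureWarning. A compact sketch of the two explicit settings, using the same mock env as the tests (ContinuousActionVecMockEnv comes from the test helpers):

from torchrl.data import Unbounded
from torchrl.envs import SerialEnv, TensorDictPrimer, TransformedEnv

# expand_specs=True: a [4] entry is broadcast over the 2-worker batch, e.g. to [2, 4]
env_expand = TransformedEnv(
    SerialEnv(2, ContinuousActionVecMockEnv),
    TensorDictPrimer(mykey=Unbounded([4]), expand_specs=True),
)
env_expand.observation_spec   # contains "mykey" with the batch dimension prepended

# expand_specs=False: shapes must already line up, so the same spec raises
env_strict = TransformedEnv(
    SerialEnv(2, ContinuousActionVecMockEnv),
    TensorDictPrimer(mykey=Unbounded([4]), expand_specs=False),
)
# env_strict.observation_spec  -> ValueError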
@@ -10310,9 +10338,8 @@ def _make_transform_env(self, out_key, base_env):
transform = KLRewardTransform(actor, out_keys=out_key)
return Compose(
TensorDictPrimer(
primers={
"sample_log_prob": Unbounded(shape=base_env.action_spec.shape[:-1])
}
sample_log_prob=Unbounded(shape=base_env.action_spec.shape[:-1]),
shape=base_env.shape,
),
transform,
)
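TensorDictPrimer now also picks up `shape` (and `device`/`batch_size`) keywords when building its internal Composite, which is what the snippet above relies on: the sample_log_prob entry is declared together with the batched base env's shape. Roughly, for a hypothetical base env with batch shape [2] and action shape [2, 6]:

from torchrl.data import Unbounded
from torchrl.envs import TensorDictPrimer

primer = TensorDictPrimer(
    sample_log_prob=Unbounded(shape=(2,)),  # action_spec.shape[:-1] of the assumed env
    shape=(2,),                             # batch shape of the assumed batched base env
)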
36 changes: 32 additions & 4 deletions torchrl/envs/batched_envs.py
@@ -1744,14 +1744,39 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# We keep track of which keys are present to let the worker know what
# should be passed to the env (we don't want to pass done states for instance)
next_td_keys = list(next_td_passthrough.keys(True, True))
next_shared_tensordict_parent = shared_tensordict_parent.get("next")

# We separate keys that are and are not present in the buffer here and not in step_and_maybe_reset.
# The reason we do that is that the policy may write stuff in 'next' that is not part of the specs of
# the batched env but part of the specs of a transformed batched env.
# If that is the case, `update_` will fail to find the entries to update.
# What we do instead is keeping the tensors on the side and putting them back after completing _step.
keys_to_update, keys_to_copy = zip(
*[
(key, None)
if key in next_shared_tensordict_parent.keys(True, True)
else (None, key)
for key in next_td_keys
]
)
keys_to_update = [key for key in keys_to_update if key is not None]
keys_to_copy = [key for key in keys_to_copy if key is not None]
data = [
{"next_td_passthrough_keys": next_td_keys}
{"next_td_passthrough_keys": keys_to_update}
for _ in range(self.num_workers)
]
shared_tensordict_parent.get("next").update_(
next_td_passthrough, non_blocking=self.non_blocking
)
if keys_to_update:
next_shared_tensordict_parent.update_(
next_td_passthrough,
non_blocking=self.non_blocking,
keys_to_update=keys_to_update,
)
if keys_to_copy:
next_td_passthrough = next_td_passthrough.select(*keys_to_copy)
else:
next_td_passthrough = None
else:
next_td_passthrough = None
data = [{} for _ in range(self.num_workers)]

if self._non_tensor_keys:
@@ -1807,6 +1832,9 @@ def select_and_clone(name, tensor):
LazyStackedTensorDict(*non_tensor_tds),
keys_to_update=self._non_tensor_keys,
)
if next_td_passthrough is not None:
out.update(next_td_passthrough)

self._sync_w2m()
if partial_steps is not None:
result = out.new_zeros(tensordict_save.shape)
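The key partition above keeps entries that the policy wrote under "next" but that the batched env's shared buffer does not hold (for instance entries introduced by a primer stacked on top of the batched env) to the side, and merges them back into the step output at the end. A simplified, standalone illustration of the same split, with placeholder key sets:

# Keys present in the shared "next" buffer of the batched env (placeholder set)
buffer_keys = {"observation", "reward", "done", "terminated"}
# Keys the policy wrote under "next" (placeholder list; the last one is primer-only)
next_td_keys = ["observation", "reward", "done", "recurrent_state_c"]

keys_to_update = [key for key in next_td_keys if key in buffer_keys]
keys_to_copy = [key for key in next_td_keys if key not in buffer_keys]

# keys_to_update -> written into the shared tensordict with update_()
# keys_to_copy   -> e.g. ["recurrent_state_c"], selected into a side tensordict and
#                   merged into the output once the workers have stepped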
37 changes: 30 additions & 7 deletions torchrl/envs/transforms/transforms.py
@@ -4984,6 +4984,7 @@ def __init__(
| Dict[NestedKey, float]
| Dict[NestedKey, Callable] = None,
reset_key: NestedKey | None = None,
expand_specs: bool = None,
**kwargs,
):
self.device = kwargs.pop("device", None)
@@ -4995,8 +4996,16 @@
)
kwargs = primers
if not isinstance(kwargs, Composite):
kwargs = Composite(kwargs)
self.primers = kwargs
shape = kwargs.pop("shape", None)
device = kwargs.pop("device", None)
if "batch_size" in kwargs.keys():
extra_kwargs = {"batch_size": kwargs.pop("batch_size")}
else:
extra_kwargs = {}
primers = Composite(kwargs, device=device, shape=shape, **extra_kwargs)
self.primers = primers
self.expand_specs = expand_specs

if random and default_value:
raise ValueError(
"Setting random to True and providing a default_value are incompatible."
@@ -5089,12 +5098,26 @@ def transform_observation_spec(self, observation_spec: Composite) -> Composite:
)

if self.primers.shape != observation_spec.shape:
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
if self.expand_specs:
self.primers = self._expand_shape(self.primers)
elif self.expand_specs is None:
warnings.warn(
f"expand_specs wasn't specified in the {type(self).__name__} constructor. "
f"The current behaviour is that the transform will attempt to set the shape of the composite "
f"spec, and if this can't be done it will be expanded. "
f"From v0.8, a mismatched shape between the spec of the transform and the env's batch_size "
f"will raise an exception.",
category=FutureWarning,
)
try:
# We try to set the primer shape to the observation spec shape
self.primers.shape = observation_spec.shape
except ValueError:
# If we fail, we expand them to that shape
self.primers = self._expand_shape(self.primers)
else:
self.primers.shape = observation_spec.shape

device = observation_spec.device
observation_spec.update(self.primers.clone().to(device))
return observation_spec
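For completeness, the expansion fallback in transform_observation_spec presumably amounts to broadcasting the primer's Composite to the outer env's batch size. A rough illustration of that broadcast using the public spec API (the shapes are illustrative):

from torchrl.data import Composite, Unbounded

primer_spec = Composite(mykey=Unbounded([4]))   # shape torch.Size([]) at construction
expanded = primer_spec.expand(2)                # e.g. a 2-worker batched env

print(expanded.shape)            # torch.Size([2])
print(expanded["mykey"].shape)   # torch.Size([2, 4])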