
Commit 19a920e

[BugFix] Fix update in serial / parallel env (#1866)
Authored by Vincent Moens
1 parent 80fc87f commit 19a920e

13 files changed, +278 -178 lines changed

test/mocking_classes.py

Lines changed: 13 additions & 6 deletions
@@ -1072,7 +1072,7 @@ def _step(
         tensordict: TensorDictBase,
     ) -> TensorDictBase:
         action = tensordict.get(self.action_key)
-        self.count += action.to(torch.int).to(self.device)
+        self.count += action.to(dtype=torch.int, device=self.device)
         tensordict = TensorDict(
             source={
                 "observation": self.count.clone(),
@@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                     3,
                 )
             ),
+            device=self.device,
         )

         self.unbatched_action_spec = CompositeSpec(
             lazy=action_specs,
+            device=self.device,
         )
         self.unbatched_reward_spec = CompositeSpec(
             {
@@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                 },
                 shape=(self.n_nested_dim,),
             )
-        }
+        },
+        device=self.device,
         )
         self.unbatched_done_spec = CompositeSpec(
             {
@@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
                 },
                 shape=(self.n_nested_dim,),
             )
-        }
+        },
+        device=self.device,
         )

         self.action_spec = self.unbatched_action_spec.expand(
@@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i):
                     "lidar": lidar,
                     "vector": vector_3d,
                     "tensor_0": tensor_0,
-                }
+                },
+                device=self.device,
             )
         elif i == 1:
             return CompositeSpec(
@@ -1497,15 +1502,17 @@ def get_agent_obs_spec(self, i):
                     "lidar": lidar,
                     "vector": vector_2d,
                     "tensor_1": tensor_1,
-                }
+                },
+                device=self.device,
             )
         elif i == 2:
             return CompositeSpec(
                 {
                     "camera": camera,
                     "vector": vector_2d,
                     "tensor_2": tensor_2,
-                }
+                },
+                device=self.device,
             )
         else:
             raise ValueError(f"Index {i} undefined for index 3")

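Why the device= additions above matter: a CompositeSpec assembled without an explicit device does not automatically follow the device of the env that owns it, so batched (serial/parallel) envs can end up comparing or updating specs that live on mismatched devices. A minimal sketch of the construction pattern, using illustrative names rather than the repository's code:

import torch
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec

device = "cuda" if torch.cuda.is_available() else "cpu"
obs_spec = CompositeSpec(
    {"observation": UnboundedContinuousTensorSpec(shape=(3,), device=device)},
    device=device,  # keeps the composite and its leaf specs on one device
)
assert obs_spec.device == torch.device(device)
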
test/test_collector.py

Lines changed: 13 additions & 5 deletions
@@ -1675,8 +1675,12 @@ def test_maxframes_error():
 @pytest.mark.parametrize("policy_device", [None, *get_available_devices()])
 @pytest.mark.parametrize("env_device", [None, *get_available_devices()])
 @pytest.mark.parametrize("storing_device", [None, *get_available_devices()])
+@pytest.mark.parametrize("parallel", [False, True])
 def test_reset_heterogeneous_envs(
-    policy_device: torch.device, env_device: torch.device, storing_device: torch.device
+    policy_device: torch.device,
+    env_device: torch.device,
+    storing_device: torch.device,
+    parallel,
 ):
     if (
         policy_device is not None
@@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs(
         env_device = torch.device("cpu")  # explicit mapping
     elif env_device is not None and env_device.type == "cuda" and policy_device is None:
         policy_device = torch.device("cpu")
-    env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2))
-    env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3))
-    env = SerialEnv(2, [env1, env2], device=env_device)
+    env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2))
+    env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3))
+    if parallel:
+        cls = ParallelEnv
+    else:
+        cls = SerialEnv
+    env = cls(2, [env1, env2], device=env_device)
     collector = SyncDataCollector(
         env,
         RandomPolicy(env.action_spec),
@@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs(
     assert (
         data[0]["next", "truncated"].squeeze()
         == torch.tensor([False, True], device=data_device).repeat(25)[:50]
-    ).all(), data[0]["next", "truncated"][:10]
+    ).all(), data[0]["next", "truncated"]
     assert (
         data[1]["next", "truncated"].squeeze()
         == torch.tensor([False, False, True], device=data_device).repeat(17)[:50]

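The new parallel parametrization leans on SerialEnv and ParallelEnv sharing the same constructor, so the backend can be swapped with a single variable. A hedged, self-contained sketch of that pattern; GymEnv and "Pendulum-v1" are stand-ins for the test's CountingEnv, and policy=None lets the collector fall back to a random policy:

from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import GymEnv

use_parallel = False  # the test exercises both backends
cls = ParallelEnv if use_parallel else SerialEnv
env = cls(2, [lambda: GymEnv("Pendulum-v1"), lambda: GymEnv("Pendulum-v1")], device="cpu")
collector = SyncDataCollector(
    env,
    policy=None,  # random policy over env.action_spec
    frames_per_batch=10,
    total_frames=20,
    storing_device="cpu",  # where collected batches land, as in the test
)
for data in collector:
    print(data.device)  # batches arrive on storing_device
collector.shutdown()
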
test/test_env.py

Lines changed: 5 additions & 2 deletions
@@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):

     @pytest.mark.parametrize("batch_size", [(1, 2)])
     @pytest.mark.parametrize("env_type", ["serial", "parallel"])
-    def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
+    @pytest.mark.parametrize("break_when_any_done", [False, True])
+    def test_vec_env(
+        self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
+    ):
         env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
         if env_type == "serial":
             vec_env = SerialEnv(n_workers, env_fun)
@@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
             rollout_steps,
             policy=policy,
             return_contiguous=False,
-            break_when_any_done=False,
+            break_when_any_done=break_when_any_done,
         )
         td = dense_stack_tds(td)
         for i in range(env_fun().n_nested_dim):

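The break_when_any_done parametrization above distinguishes two rollout modes of a batched env: with False, finished workers are reset in place and the rollout always spans the requested number of steps; with True, the rollout stops at the first done flag and may be shorter. A hedged sketch, with GymEnv and StepCounter standing in for the test's HeterogeneousCountingEnv:

from torchrl.envs import SerialEnv, StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

make_env = lambda: TransformedEnv(GymEnv("CartPole-v1"), StepCounter(max_steps=3))
vec_env = SerialEnv(2, make_env)

td_keep_going = vec_env.rollout(8, break_when_any_done=False)  # always 8 steps
td_stop_early = vec_env.rollout(8, break_when_any_done=True)   # stops at first done
print(td_keep_going.shape, td_stop_early.shape)
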
test/test_tensordictmodules.py

Lines changed: 31 additions & 4 deletions
@@ -21,6 +21,7 @@
     CompositeSpec,
     UnboundedContinuousTensorSpec,
 )
+from torchrl.envs import EnvCreator, SerialEnv
 from torchrl.envs.utils import set_exploration_type, step_mdp
 from torchrl.modules import (
     AdditiveGaussianWrapper,
@@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based):
         )

     @pytest.mark.parametrize("python_based", [True, False])
-    def test_lstm_parallel_env(self, python_based):
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("heterogeneous", [True, False])
+    def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
         from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

+        torch.manual_seed(0)
         device = "cuda" if torch.cuda.device_count() else "cpu"
         # tests that hidden states are carried over with parallel envs
         lstm_module = LSTMModule(
@@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based):
             device=device,
             python_based=python_based,
         )
+        if parallel:
+            cls = ParallelEnv
+        else:
+            cls = SerialEnv

         def create_transformed_env():
             primer = lstm_module.make_tensordict_primer()
@@ -1807,7 +1815,12 @@ def create_transformed_env():
             env.append_transform(primer)
             return env

-        env = ParallelEnv(
+        if heterogeneous:
+            create_transformed_env = [
+                EnvCreator(create_transformed_env),
+                EnvCreator(create_transformed_env),
+            ]
+        env = cls(
             create_env_fn=create_transformed_env,
             num_workers=2,
         )
@@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based):
         )

     @pytest.mark.parametrize("python_based", [True, False])
-    def test_gru_parallel_env(self, python_based):
+    @pytest.mark.parametrize("parallel", [True, False])
+    @pytest.mark.parametrize("heterogeneous", [True, False])
+    def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
         from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

+        torch.manual_seed(0)
+
         device = "cuda" if torch.cuda.device_count() else "cpu"
         # tests that hidden states are carried over with parallel envs
         gru_module = GRUModule(
@@ -2134,7 +2151,17 @@ def create_transformed_env():
             env.append_transform(primer)
             return env

-        env = ParallelEnv(
+        if parallel:
+            cls = ParallelEnv
+        else:
+            cls = SerialEnv
+        if heterogeneous:
+            create_transformed_env = [
+                EnvCreator(create_transformed_env),
+                EnvCreator(create_transformed_env),
+            ]
+
+        env = cls(
             create_env_fn=create_transformed_env,
             num_workers=2,
         )

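The heterogeneous branch added above passes a list of EnvCreator-wrapped factories instead of a single callable, which is presumably one of the paths the new serial/parallel tests are meant to cover. A hedged sketch with an illustrative base env (the real test wraps its env with InitTracker plus the LSTM/GRU tensordict primer):

from torchrl.envs import EnvCreator, InitTracker, SerialEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

def make_env():
    # stands in for the test's create_transformed_env
    return TransformedEnv(GymEnv("CartPole-v1"), InitTracker())

env = SerialEnv(
    create_env_fn=[EnvCreator(make_env), EnvCreator(make_env)],
    num_workers=2,
)
rollout = env.rollout(4)
print(rollout["is_init"].shape)  # InitTracker adds an is_init flag per step
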
torchrl/collectors/collectors.py

Lines changed: 1 addition & 1 deletion
@@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase:

             if self.storing_device is not None:
                 tensordicts.append(
-                    self._shuttle.to(self.storing_device, non_blocking=False)
+                    self._shuttle.to(self.storing_device, non_blocking=True)
                 )
             else:
                 tensordicts.append(self._shuttle)

torchrl/data/replay_buffers/storages.py

Lines changed: 1 addition & 1 deletion
@@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
         # to be deprecated in v0.4
         def map_device(tensor):
             if tensor.device != self.device:
-                return tensor.to(self.device, non_blocking=False)
+                return tensor.to(self.device, non_blocking=True)
             return tensor

         if is_tensor_collection(result):

torchrl/data/rlhf/dataset.py

Lines changed: 1 addition & 1 deletion
@@ -394,7 +394,7 @@ def get_dataloader(
     )
     out = TensorDictReplayBuffer(
         storage=TensorStorage(data),
-        collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False),
+        collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
         sampler=SamplerWithoutReplacement(drop_last=True),
         batch_size=batch_size,
         prefetch=prefetch,

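The three non_blocking changes above (collectors.py, storages.py, dataset.py) all switch device transfers from non_blocking=False to non_blocking=True. A short sketch of the underlying PyTorch semantics: the copy overlaps with GPU work only when the source lives in pinned host memory; otherwise it degrades to an ordinary synchronous copy, so results are unchanged either way.

import torch

if torch.cuda.is_available():
    host = torch.randn(1024, 1024, pin_memory=True)  # pinned, so the copy can be async
    dev = host.to("cuda", non_blocking=True)
    torch.cuda.synchronize()  # wait before relying on dev's values host-side
    print(dev.device)
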