Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix update in serial / parallel env #1866

Merged
merged 35 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ def _step(
tensordict: TensorDictBase,
) -> TensorDictBase:
action = tensordict.get(self.action_key)
self.count += action.to(torch.int).to(self.device)
self.count += action.to(dtype=torch.int, device=self.device)
tensordict = TensorDict(
source={
"observation": self.count.clone(),
Expand Down Expand Up @@ -1426,10 +1426,12 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
3,
)
),
device=self.device,
)

self.unbatched_action_spec = CompositeSpec(
lazy=action_specs,
device=self.device,
)
self.unbatched_reward_spec = CompositeSpec(
{
Expand All @@ -1441,7 +1443,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)
self.unbatched_done_spec = CompositeSpec(
{
Expand All @@ -1455,7 +1458,8 @@ def __init__(self, max_steps: int = 5, start_val: int = 0, **kwargs):
},
shape=(self.n_nested_dim,),
)
}
},
device=self.device,
)

self.action_spec = self.unbatched_action_spec.expand(
Expand Down Expand Up @@ -1488,7 +1492,8 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_3d,
"tensor_0": tensor_0,
}
},
device=self.device,
)
elif i == 1:
return CompositeSpec(
Expand All @@ -1497,15 +1502,17 @@ def get_agent_obs_spec(self, i):
"lidar": lidar,
"vector": vector_2d,
"tensor_1": tensor_1,
}
},
device=self.device,
)
elif i == 2:
return CompositeSpec(
{
"camera": camera,
"vector": vector_2d,
"tensor_2": tensor_2,
}
},
device=self.device,
)
else:
raise ValueError(f"Index {i} undefined for index 3")
Expand Down
18 changes: 13 additions & 5 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1675,8 +1675,12 @@ def test_maxframes_error():
@pytest.mark.parametrize("policy_device", [None, *get_available_devices()])
@pytest.mark.parametrize("env_device", [None, *get_available_devices()])
@pytest.mark.parametrize("storing_device", [None, *get_available_devices()])
@pytest.mark.parametrize("parallel", [False, True])
def test_reset_heterogeneous_envs(
policy_device: torch.device, env_device: torch.device, storing_device: torch.device
policy_device: torch.device,
env_device: torch.device,
storing_device: torch.device,
parallel,
):
if (
policy_device is not None
Expand All @@ -1686,9 +1690,13 @@ def test_reset_heterogeneous_envs(
env_device = torch.device("cpu") # explicit mapping
elif env_device is not None and env_device.type == "cuda" and policy_device is None:
policy_device = torch.device("cpu")
env1 = lambda: TransformedEnv(CountingEnv(), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(), StepCounter(3))
env = SerialEnv(2, [env1, env2], device=env_device)
env1 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(2))
env2 = lambda: TransformedEnv(CountingEnv(device="cpu"), StepCounter(3))
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
env = cls(2, [env1, env2], device=env_device)
collector = SyncDataCollector(
env,
RandomPolicy(env.action_spec),
Expand All @@ -1705,7 +1713,7 @@ def test_reset_heterogeneous_envs(
assert (
data[0]["next", "truncated"].squeeze()
== torch.tensor([False, True], device=data_device).repeat(25)[:50]
).all(), data[0]["next", "truncated"][:10]
).all(), data[0]["next", "truncated"]
assert (
data[1]["next", "truncated"].squeeze()
== torch.tensor([False, False, True], device=data_device).repeat(17)[:50]
Expand Down
7 changes: 5 additions & 2 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,7 +2095,10 @@ def test_rollout_policy(self, batch_size, rollout_steps, count):

@pytest.mark.parametrize("batch_size", [(1, 2)])
@pytest.mark.parametrize("env_type", ["serial", "parallel"])
def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
@pytest.mark.parametrize("break_when_any_done", [False, True])
def test_vec_env(
self, batch_size, env_type, break_when_any_done, rollout_steps=4, n_workers=2
):
env_fun = lambda: HeterogeneousCountingEnv(batch_size=batch_size)
if env_type == "serial":
vec_env = SerialEnv(n_workers, env_fun)
Expand All @@ -2109,7 +2112,7 @@ def test_vec_env(self, batch_size, env_type, rollout_steps=4, n_workers=2):
rollout_steps,
policy=policy,
return_contiguous=False,
break_when_any_done=False,
break_when_any_done=break_when_any_done,
)
td = dense_stack_tds(td)
for i in range(env_fun().n_nested_dim):
Expand Down
35 changes: 31 additions & 4 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CompositeSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.envs import EnvCreator, SerialEnv
from torchrl.envs.utils import set_exploration_type, step_mdp
from torchrl.modules import (
AdditiveGaussianWrapper,
Expand Down Expand Up @@ -1782,9 +1783,12 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_lstm_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_lstm_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)
device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with serial and parallel envs
lstm_module = LSTMModule(
Expand All @@ -1796,6 +1800,10 @@ def test_lstm_parallel_env(self, python_based):
device=device,
python_based=python_based,
)
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv

def create_transformed_env():
primer = lstm_module.make_tensordict_primer()
Expand All @@ -1807,7 +1815,12 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]
env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down Expand Up @@ -2109,9 +2122,13 @@ def test_multi_consecutive(self, shape, python_based):
)

@pytest.mark.parametrize("python_based", [True, False])
def test_gru_parallel_env(self, python_based):
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("heterogeneous", [True, False])
def test_gru_parallel_env(self, python_based, parallel, heterogeneous):
from torchrl.envs import InitTracker, ParallelEnv, TransformedEnv

torch.manual_seed(0)

device = "cuda" if torch.cuda.device_count() else "cpu"
# tests that hidden states are carried over with serial and parallel envs
gru_module = GRUModule(
Expand All @@ -2134,7 +2151,17 @@ def create_transformed_env():
env.append_transform(primer)
return env

env = ParallelEnv(
if parallel:
cls = ParallelEnv
else:
cls = SerialEnv
if heterogeneous:
create_transformed_env = [
EnvCreator(create_transformed_env),
EnvCreator(create_transformed_env),
]

env = cls(
create_env_fn=create_transformed_env,
num_workers=2,
)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,7 +1077,7 @@ def rollout(self) -> TensorDictBase:

if self.storing_device is not None:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=False)
self._shuttle.to(self.storing_device, non_blocking=True)
)
else:
tensordicts.append(self._shuttle)
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -894,7 +894,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
# to be deprecated in v0.4
def map_device(tensor):
if tensor.device != self.device:
return tensor.to(self.device, non_blocking=False)
return tensor.to(self.device, non_blocking=True)
return tensor

if is_tensor_collection(result):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/rlhf/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def get_dataloader(
)
out = TensorDictReplayBuffer(
storage=TensorStorage(data),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=False),
collate_fn=lambda x: x.as_tensor().to(device, non_blocking=True),
sampler=SamplerWithoutReplacement(drop_last=True),
batch_size=batch_size,
prefetch=prefetch,
Expand Down
Loading
Loading