Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Robust sync for non_blocking=True #2034

Merged
merged 24 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 42 additions & 12 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from torchrl.collectors.utils import split_trajectories
from torchrl.data.tensor_specs import TensorSpec
from torchrl.data.utils import CloudpickleWrapper, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.common import _do_nothing, EnvBase
from torchrl.envs.transforms import StepCounter, TransformedEnv
from torchrl.envs.utils import (
_aggregate_end_of_traj,
Expand Down Expand Up @@ -472,8 +472,45 @@ def __init__(
)

self.storing_device = storing_device
if self.storing_device is not None and self.storing_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_storage = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_storage = torch.mps.synchronize
elif self.storing_device.type == "cpu":
self._sync_storage = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_storage = _do_nothing

self.env_device = env_device
if self.env_device is not None and self.env_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_env = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_env = torch.mps.synchronize
elif self.env_device.type == "cpu":
self._sync_env = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_env = _do_nothing
self.policy_device = policy_device
if self.policy_device is not None and self.policy_device.type != "cuda":
# Cuda handles sync
if torch.cuda.is_available():
self._sync_policy = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_policy = torch.mps.synchronize
elif self.policy_device.type == "cpu":
self._sync_policy = _do_nothing
else:
raise RuntimeError("Non supported device")
else:
self._sync_policy = _do_nothing
self.device = device
# Check if we need to cast things from device to device
# If the policy has a None device and the env too, no need to cast (we don't know
Expand Down Expand Up @@ -503,7 +540,7 @@ def __init__(
if self.env_device:
self.env: EnvBase = self.env.to(self.env_device)
elif self.env.device is not None:
# we we did not receive an env device, we use the device of the env
# we did not receive an env device, we use the device of the env
self.env_device = self.env.device

# If the storing device is not the same as the policy device, we have
Expand Down Expand Up @@ -915,6 +952,7 @@ def rollout(self) -> TensorDictBase:
policy_input = self._shuttle.to(
self.policy_device, non_blocking=True
)
self._sync_policy()
elif self.policy_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
Expand All @@ -933,6 +971,7 @@ def rollout(self) -> TensorDictBase:
if self._cast_to_env_device:
if self.env_device is not None:
env_input = self._shuttle.to(self.env_device, non_blocking=True)
self._sync_env()
elif self.env_device is None:
# we know the tensordict has a device otherwise we would not be here
# we can pass this, clear_device_ must have been called earlier
Expand All @@ -954,6 +993,7 @@ def rollout(self) -> TensorDictBase:
tensordicts.append(
self._shuttle.to(self.storing_device, non_blocking=True)
)
self._sync_storage()
else:
tensordicts.append(self._shuttle)

Expand Down Expand Up @@ -1000,16 +1040,6 @@ def rollout(self) -> TensorDictBase:
)
return self._final_rollout

@staticmethod
def _update_device_wise(tensor0, tensor1):
# given 2 tensors, returns tensor0 if their identity matches,
# or a copy of tensor1 on the device of tensor0 otherwise
if tensor1 is None or tensor1 is tensor0:
return tensor0
if tensor1.device == tensor0.device:
return tensor1
return tensor1.to(tensor0.device, non_blocking=True)

@torch.no_grad()
def reset(self, index=None, **kwargs) -> None:
"""Resets the environments to a new initial state."""
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/replay_buffers/storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ def get(self, index: Union[int, Sequence[int], slice]) -> Any:
# to be deprecated in v0.4
def map_device(tensor):
if tensor.device != self.device:
return tensor.to(self.device, non_blocking=True)
return tensor.to(self.device, non_blocking=False)
return tensor

if is_tensor_collection(result):
Expand Down
Loading
Loading