[BugFix] Fix collector tests where device ordinal is needed #2240

Merged · 1 commit · Jun 20, 2024
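This PR fixes collector tests that compare a configured device such as torch.device("cuda") against the device reported by tensors, environments, and collected batches, which is always fully qualified (e.g. cuda:0). Because torch.device equality is index-sensitive, those assertions fail whenever the configured device omits the ordinal. The fix normalizes the expected device through torchrl's _make_ordinal_device helper before comparing. As a rough sketch of what such a helper does (the real implementation lives in torchrl._utils; this simplified version covering only the CUDA case is an assumption for illustration):

import torch

def make_ordinal_device(device: torch.device) -> torch.device:
    # Illustrative stand-in for torchrl._utils._make_ordinal_device:
    # return `device` with an explicit ordinal filled in.
    if device.type == "cuda" and device.index is None:
        # A bare "cuda" device refers to the current CUDA device.
        return torch.device("cuda", torch.cuda.current_device())
    return device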
26 changes: 19 additions & 7 deletions test/test_collector.py
@@ -53,7 +53,13 @@
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictSequential
 
 from torch import nn
-from torchrl._utils import _replace_last, logger as torchrl_logger, prod, seed_generator
+from torchrl._utils import (
+    _make_ordinal_device,
+    _replace_last,
+    logger as torchrl_logger,
+    prod,
+    seed_generator,
+)
 from torchrl.collectors import aSyncDataCollector, SyncDataCollector
 from torchrl.collectors.collectors import (
     _Interruptor,
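For reference, torch.device equality compares both the device type and the index, so an ordinal-less device never equals its fully qualified counterpart; a quick plain-PyTorch check:

import torch

d = torch.device("cuda")     # index is None
d0 = torch.device("cuda:0")  # index is 0
print(d == d0)               # False: equality is index-sensitive
print(d.index, d0.index)     # None 0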
@@ -285,15 +291,19 @@ def __init__(self, default_device):
         self.action_spec = UnboundedContinuousTensorSpec(
             (), device=self.default_device
         )
-        assert self.device == torch.device(self.default_device)
+        assert self.device == _make_ordinal_device(
+            torch.device(self.default_device)
+        )
         assert self.full_observation_spec is not None
         assert self.full_done_spec is not None
         assert self.full_state_spec is not None
         assert self.full_action_spec is not None
         assert self.full_reward_spec is not None
 
     def _step(self, tensordict):
-        assert tensordict.device == torch.device(self.default_device)
+        assert tensordict.device == _make_ordinal_device(
+            torch.device(self.default_device)
+        )
         with torch.device(self.default_device):
             return TensorDict(
                 {
@@ -339,7 +349,9 @@ class PolicyWithDevice(TensorDictModuleBase):
     default_device = "cuda:0" if torch.cuda.device_count() else "cpu"
 
     def forward(self, tensordict):
-        assert tensordict.device == torch.device(self.default_device)
+        assert tensordict.device == _make_ordinal_device(
+            torch.device(self.default_device)
+        )
         return tensordict.set("action", torch.zeros((), device=self.default_device))
 
 @pytest.mark.parametrize("main_device", get_default_devices())
@@ -1436,7 +1448,7 @@ def env_fn(seed):
     )
     assert collector._use_buffers
     batch = next(collector.iterator())
-    assert batch.device == torch.device(storing_device)
+    assert batch.device == _make_ordinal_device(torch.device(storing_device))
     collector.shutdown()
 
     collector = MultiSyncDataCollector(
@@ -1459,7 +1471,7 @@
         cat_results="stack",
     )
     batch = next(collector.iterator())
-    assert batch.device == torch.device(storing_device)
+    assert batch.device == _make_ordinal_device(torch.device(storing_device))
     collector.shutdown()
 
     collector = MultiaSyncDataCollector(
@@ -1481,7 +1493,7 @@
         ],
     )
     batch = next(collector.iterator())
-    assert batch.device == torch.device(storing_device)
+    assert batch.device == _make_ordinal_device(torch.device(storing_device))
     collector.shutdown()
     del collector
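The same normalize-then-compare pattern applies to any test parametrized over device strings. A minimal hypothetical example (the test name and parametrization are stand-ins, not part of this PR; only the _make_ordinal_device import mirrors the diff above):

import pytest
import torch
from torchrl._utils import _make_ordinal_device

@pytest.mark.parametrize(
    "device", ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]
)
def test_device_is_normalized(device):
    # Normalize the requested device before comparing it with the
    # fully qualified device reported by a tensor allocated on it.
    expected = _make_ordinal_device(torch.device(device))
    assert torch.zeros((), device=device).device == expected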