Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 6, 2024
1 parent 6c6c19e commit 3fafe1d
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2678,16 +2678,21 @@ def test_rb_multidim(self, datatype, datadim, rbtype, storage_cls):
],
],
)
@pytest.mark.parametrize("env_device", get_default_devices())
def test_rb_multidim_collector(
self, rbtype, storage_cls, writer_cls, sampler_cls, transform
self, rbtype, storage_cls, writer_cls, sampler_cls, transform, env_device
):
from _utils_internal import CARTPOLE_VERSIONED

torch.manual_seed(0)
env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()))
env = SerialEnv(2, lambda: GymEnv(CARTPOLE_VERSIONED()), device=env_device)
env.set_seed(0)
collector = SyncDataCollector(
env, RandomPolicy(env.action_spec), frames_per_batch=4, total_frames=16
env,
RandomPolicy(env.action_spec),
frames_per_batch=4,
total_frames=16,
device=env_device,
)
if writer_cls is TensorDictMaxValueWriter:
with pytest.raises(
Expand All @@ -2712,6 +2717,7 @@ def test_rb_multidim_collector(
rb.append_transform(t())
try:
for i, data in enumerate(collector): # noqa: B007
assert data.device == torch.device(env_device)
rb.extend(data)
if isinstance(rb, TensorDictReplayBuffer) and transform is not None:
# this should fail bc we can't set the indices after executing the transform.
Expand All @@ -2721,6 +2727,7 @@ def test_rb_multidim_collector(
rb.sample()
return
s = rb.sample()
assert s.device == torch.device("cpu")
rbtot = rb[:]
assert rbtot.shape[0] == 2
assert len(rb) == rbtot.numel()
Expand Down

0 comments on commit 3fafe1d

Please sign in to comment.