Skip to content

[Feature] Fine grained DeviceCastTransform #2041

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
234 changes: 233 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9331,7 +9331,239 @@ def test_transform_inverse(self):
return


class TestDeviceCastTransformPart(TransformBase):
    """Tests for ``DeviceCastTransform`` restricted to a subset of keys.

    Every test builds the transform with explicit ``in_keys``/``out_keys``
    (forward direction) and ``in_keys_inv``/``out_keys_inv`` (inverse
    direction).  Because only some entries are cast, the resulting env /
    tensordict no longer has a single consistent device, so the expected
    ``device`` everywhere is ``None``.  ``cpu:0`` -> ``cpu:1`` is used so
    the tests run without accelerators.
    """

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_single_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
        """Single transformed env: specs stay consistent, env device is None."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        env = TransformedEnv(
            env,
            DeviceCastTransform(
                "cpu:1",
                in_keys=in_keys,
                out_keys=out_keys,
                in_keys_inv=in_keys_inv,
                out_keys_inv=out_keys_inv,
            ),
        )
        # Mixed devices across keys => no single env device.
        assert env.device is None
        check_env_specs(env)

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_serial_trans_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
        """SerialEnv of per-env-transformed envs keeps device=None."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform(
                    "cpu:1",
                    in_keys=in_keys,
                    out_keys=out_keys,
                    in_keys_inv=in_keys_inv,
                    out_keys_inv=out_keys_inv,
                ),
            )

        env = SerialEnv(2, make_env)
        assert env.device is None
        check_env_specs(env)

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_parallel_trans_env_check(
        self, in_keys, out_keys, in_keys_inv, out_keys_inv
    ):
        """ParallelEnv of per-env-transformed envs keeps device=None."""

        def make_env():
            return TransformedEnv(
                ContinuousActionVecMockEnv(device="cpu:0"),
                DeviceCastTransform(
                    "cpu:1",
                    in_keys=in_keys,
                    out_keys=out_keys,
                    in_keys_inv=in_keys_inv,
                    out_keys_inv=out_keys_inv,
                ),
            )

        # "fork" is unsafe with CUDA initialized, hence the start-method switch.
        env = ParallelEnv(
            2,
            make_env,
            mp_start_method="fork" if not torch.cuda.is_available() else "spawn",
        )
        assert env.device is None
        try:
            check_env_specs(env)
        finally:
            # Always reap worker processes, even on assertion failure.
            env.close()

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_trans_serial_env_check(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
        """Transform applied on top of a SerialEnv keeps device=None."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(
            SerialEnv(2, make_env),
            DeviceCastTransform(
                "cpu:1",
                in_keys=in_keys,
                out_keys=out_keys,
                in_keys_inv=in_keys_inv,
                out_keys_inv=out_keys_inv,
            ),
        )
        assert env.device is None
        check_env_specs(env)

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_trans_parallel_env_check(
        self, in_keys, out_keys, in_keys_inv, out_keys_inv
    ):
        """Transform applied on top of a ParallelEnv keeps device=None."""

        def make_env():
            return ContinuousActionVecMockEnv(device="cpu:0")

        env = TransformedEnv(
            ParallelEnv(
                2,
                make_env,
                mp_start_method="fork" if not torch.cuda.is_available() else "spawn",
            ),
            DeviceCastTransform(
                "cpu:1",
                in_keys=in_keys,
                out_keys=out_keys,
                in_keys_inv=in_keys_inv,
                out_keys_inv=out_keys_inv,
            ),
        )
        assert env.device is None
        try:
            check_env_specs(env)
        finally:
            env.close()

    def test_transform_no_env(self):
        """Calling the transform directly on a tensordict drops its device."""
        t = DeviceCastTransform("cpu:1", "cpu:0", in_keys=["a"], out_keys=["b"])
        td = TensorDict({"a": torch.randn((), device="cpu:0")}, [], device="cpu:0")
        tdt = t._call(td)
        assert tdt.device is None

    @pytest.mark.parametrize("in_keys", ["observation"])
    @pytest.mark.parametrize("out_keys", [None, ["obs_device"]])
    @pytest.mark.parametrize("in_keys_inv", ["action"])
    @pytest.mark.parametrize("out_keys_inv", [None, ["action_device"]])
    def test_transform_env(self, in_keys, out_keys, in_keys_inv, out_keys_inv):
        """The transform records both target and original devices."""
        env = ContinuousActionVecMockEnv(device="cpu:0")
        env = TransformedEnv(
            env,
            DeviceCastTransform(
                "cpu:1",
                in_keys=in_keys,
                out_keys=out_keys,
                in_keys_inv=in_keys_inv,
                out_keys_inv=out_keys_inv,
            ),
        )
        assert env.device is None
        assert env.transform.device == torch.device("cpu:1")
        # orig_device is inferred from the wrapped env when not passed explicitly.
        assert env.transform.orig_device == torch.device("cpu:0")

    def test_transform_compose(self):
        """Forward and inverse calls through a Compose both drop the device."""
        t = Compose(
            DeviceCastTransform(
                "cpu:1",
                "cpu:0",
                in_keys=["a"],
                out_keys=["b"],
                in_keys_inv=["c"],
                out_keys_inv=["d"],
            )
        )

        td = TensorDict(
            {
                "a": torch.randn((), device="cpu:0"),
                "c": torch.randn((), device="cpu:1"),
            },
            [],
            device="cpu:0",
        )
        tdt = t._call(td)
        tdit = t._inv_call(td)

        assert tdt.device is None
        assert tdit.device is None

    def test_transform_model(self):
        """The transform works as a plain nn.Module inside nn.Sequential."""
        t = nn.Sequential(
            Compose(
                DeviceCastTransform(
                    "cpu:1",
                    "cpu:0",
                    in_keys=["a"],
                    out_keys=["b"],
                    in_keys_inv=["c"],
                    out_keys_inv=["d"],
                )
            )
        )
        td = TensorDict(
            {
                "a": torch.randn((), device="cpu:0"),
                "c": torch.randn((), device="cpu:1"),
            },
            [],
            device="cpu:0",
        )
        tdt = t(td)

        assert tdt.device is None

    @pytest.mark.parametrize("rbclass", [ReplayBuffer, TensorDictReplayBuffer])
    @pytest.mark.parametrize("storage", [LazyTensorStorage])
    def test_transform_rb(self, rbclass, storage):
        """Appending the transform to a replay buffer leaves stored/sampled data device-less."""
        # we don't test casting to cuda on Memmap tensor storage since it's discouraged
        t = Compose(
            DeviceCastTransform(
                "cpu:1",
                "cpu:0",
                in_keys=["a"],
                out_keys=["b"],
                in_keys_inv=["c"],
                out_keys_inv=["d"],
            )
        )
        rb = rbclass(storage=storage(max_size=20, device="auto"))
        rb.append_transform(t)
        td = TensorDict(
            {
                "a": torch.randn((), device="cpu:0"),
                "c": torch.randn((), device="cpu:1"),
            },
            [],
            device="cpu:0",
        )
        rb.add(td)
        assert rb._storage._storage.device is None
        assert rb.sample(4).device is None

    def test_transform_inverse(self):
        """Inverse behavior is already covered by the tests above."""
        # Tested before
        return


class TestDeviceCastTransformWhole(TransformBase):
def test_single_trans_env_check(self):
env = ContinuousActionVecMockEnv(device="cpu:0")
env = TransformedEnv(env, DeviceCastTransform("cpu:1"))
Expand Down
7 changes: 5 additions & 2 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def decorator(func):

def clear_device_(self):
    """A no-op for all leaf specs (which must have a device).

    Returns:
        self, so the call can be chained fluently like the CompositeSpec
        variants of this method.
    """
    # Leaf specs keep their device; returning self keeps the fluent API
    # consistent with CompositeSpec.clear_device_().
    return self

def encode(
self, val: Union[np.ndarray, torch.Tensor], *, ignore_device=False
Expand Down Expand Up @@ -866,6 +866,7 @@ def clear_device_(self):
"""Clears the device of the CompositeSpec."""
for spec in self._specs:
spec.clear_device_()
return self

def __getitem__(self, item):
is_key = isinstance(item, str) or (
Expand Down Expand Up @@ -3594,8 +3595,10 @@ def device(self, device: DEVICE_TYPING):

def clear_device_(self):
    """Clears the device of the CompositeSpec.

    Sets this spec's own device to ``None`` and recursively clears the
    device of every child spec.

    Returns:
        self, for fluent chaining.
    """
    self._device = None
    # _specs is a mapping here, so iterate its values, not its keys.
    for spec in self._specs.values():
        spec.clear_device_()
    return self

def __getitem__(self, idx):
"""Indexes the current CompositeSpec based on the provided index."""
Expand Down
Loading