Skip to content

Commit

Permalink
Merge branch 'release/0.3.2' of https://github.com/pytorch/rl into release/0.3.2
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Apr 7, 2024
2 parents 10b72f3 + 78c92c0 commit 6d5980b
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -6191,7 +6191,7 @@ def zero_param(p):
if isinstance(p, nn.Parameter):
p.data.zero_()

params.apply(zero_param)
params.apply(zero_param, filter_empty=True)

# assert len(list(floss_fn.parameters())) == 0
with params.to_module(loss_fn):
Expand Down
4 changes: 2 additions & 2 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2417,7 +2417,7 @@ def test_transform_rb(self, include_forward, rbclass):
)
rb.extend(td)

storage = rb._storage._storage[:]
storage = rb._storage[:]

assert storage["action"].shape[-1] == 7
td = rb.sample(10)
Expand Down Expand Up @@ -8734,7 +8734,7 @@ def test_transform_rb(self, create_copy, inverse, rbclass):
rb.append_transform(t)
rb.extend(tensordict)

assert "a" in rb._storage._storage.keys()
assert "a" in rb._storage[:].keys()
sample = rb.sample(2)
if create_copy:
assert "a" in sample.keys()
Expand Down
1 change: 0 additions & 1 deletion torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,6 @@ def select_and_clone(name, tensor):
nested_keys=True,
filter_empty=True,
)
del out["next"]

if out.device != device:
if device is None:
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def metadata_from_env(env) -> EnvMetaData:
def fill_device_map(name, val, device_map=device_map):
device_map[name] = val.device

tensordict.named_apply(fill_device_map, nested_keys=True)
tensordict.named_apply(fill_device_map, nested_keys=True, filter_empty=True)
return EnvMetaData(
tensordict, specs, batch_size, env_str, device, batch_locked, device_map
)
Expand Down Expand Up @@ -2843,7 +2843,7 @@ def _sync_device(self):
sync_func = self.__dict__.get("_sync_device_val", None)
if sync_func is None:
device = self.device
if device.type != "cuda":
if device is not None and device.type != "cuda":
if torch.cuda.is_available():
self._sync_device_val = torch.cuda.synchronize
elif torch.backends.mps.is_available():
Expand Down
4 changes: 2 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,7 +3684,7 @@ def __init__(
if torch.cuda.is_available():
self._sync_device = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_device = torch.cuda.synchronize
self._sync_device = torch.mps.synchronize
elif device.type == "cpu":
self._sync_device = _do_nothing
else:
Expand Down Expand Up @@ -3739,7 +3739,7 @@ def _sync_orig_device(self):
if torch.cuda.is_available():
self._sync_orig_device_val = torch.cuda.synchronize
elif torch.backends.mps.is_available():
self._sync_orig_device_val = torch.cuda.synchronize
self._sync_orig_device_val = torch.mps.synchronize
elif device.type == "cpu":
self._sync_orig_device_val = _do_nothing
else:
Expand Down
11 changes: 9 additions & 2 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _compare_and_expand(param):
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
call_on_nested=True,
filter_empty=True,
)
if not isinstance(param, nn.Parameter):
buffer = param.expand(expand_dim, *param.shape).clone()
Expand All @@ -276,6 +277,7 @@ def _compare_and_expand(param):
_compare_and_expand,
batch_size=[expand_dim, *params.shape],
call_on_nested=True,
filter_empty=True,
),
no_convert=True,
)
Expand All @@ -296,7 +298,9 @@ def _compare_and_expand(param):
# set the functional module: we need to convert the params to non-differentiable params
# otherwise they will appear twice in parameters
with params.apply(
self._make_meta_params, device=torch.device("meta")
self._make_meta_params,
device=torch.device("meta"),
filter_empty=True,
).to_module(module):
# avoid buffers and params being exposed
self.__dict__[module_name] = deepcopy(module)
Expand All @@ -306,7 +310,10 @@ def _compare_and_expand(param):
# if create_target_params:
# we create a TensorDictParams to keep the target params as Buffer instances
target_params = TensorDictParams(
params.apply(_make_target_param(clone=create_target_params)),
params.apply(
_make_target_param(clone=create_target_params),
filter_empty=True,
),
no_convert=True,
)
setattr(self, name_params_target + "_params", target_params)
Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def __init__(

actor_critic = ActorCriticWrapper(actor_network, value_network)
params = TensorDict.from_module(actor_critic)
params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
params_meta = params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=True
)
with params_meta.to_module(actor_critic):
self.__dict__["actor_critic"] = deepcopy(actor_critic)

Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/multiagent/qmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def __init__(
global_value_network = SafeSequential(local_value_network, mixer_network)
params = TensorDict.from_module(global_value_network)
with params.apply(
self._make_meta_params, device=torch.device("meta")
self._make_meta_params,
device=torch.device("meta"),
filter_empty=True,
).to_module(global_value_network):
self.__dict__["global_value_network"] = deepcopy(global_value_network)

Expand Down

0 comments on commit 6d5980b

Please sign in to comment.