Skip to content

Commit

Permalink
[Refactor] Use filter_empty=False in apply for params (#1882)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 6, 2024
1 parent e53eb73 commit 1fe745a
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
10 changes: 7 additions & 3 deletions torchrl/collectors/collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ def _map_to_device_params(param, device):

# Create a stateless policy, then populate this copy with params on device
with param_and_buf.apply(
functools.partial(_map_to_device_params, device="meta")
functools.partial(_map_to_device_params, device="meta"),
filter_empty=False,
).to_module(policy):
policy = deepcopy(policy)

param_and_buf.apply(
functools.partial(_map_to_device_params, device=self.policy_device)
functools.partial(_map_to_device_params, device=self.policy_device),
filter_empty=False,
).to_module(policy)

return policy, get_weights_fn
Expand Down Expand Up @@ -1495,7 +1497,9 @@ def map_weight(
weight = nn.Parameter(weight, requires_grad=False)
return weight

local_policy_weights = TensorDictParams(policy_weights.apply(map_weight))
local_policy_weights = TensorDictParams(
policy_weights.apply(map_weight, filter_empty=False)
)

def _get_weight_fn(weights=policy_weights):
# This function will give the local_policy_weight the original weights.
Expand Down
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/rlhf.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def __init__(

# check that the model has parameters
params = TensorDict.from_module(actor)
with params.apply(_stateless_param, device="meta").to_module(actor):
with params.apply(
_stateless_param, device="meta", filter_empty=False
).to_module(actor):
# copy a stateless actor
self.__dict__["functional_actor"] = deepcopy(actor)
# we need to register these params as buffer to have `to` and similar
Expand All @@ -129,7 +131,7 @@ def _make_detached_param(x):
)
return x.clone()

self.frozen_params = params.apply(_make_detached_param)
self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
if requires_grad:
# includes the frozen params/buffers in the module parameters/buffers
self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
Expand Down
10 changes: 7 additions & 3 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ def _compare_and_expand(param):

params = TensorDictParams(
params.apply(
_compare_and_expand, batch_size=[expand_dim, *params.shape]
_compare_and_expand,
batch_size=[expand_dim, *params.shape],
filter_empty=False,
),
no_convert=True,
)
Expand All @@ -283,7 +285,7 @@ def _compare_and_expand(param):
# set the functional module: we need to convert the params to non-differentiable params
# otherwise they will appear twice in parameters
with params.apply(
self._make_meta_params, device=torch.device("meta")
self._make_meta_params, device=torch.device("meta"), filter_empty=False
).to_module(module):
# avoid buffers and params being exposed
self.__dict__[module_name] = deepcopy(module)
Expand All @@ -293,7 +295,9 @@ def _compare_and_expand(param):
# if create_target_params:
# we create a TensorDictParams to keep the target params as Buffer instances
target_params = TensorDictParams(
params.apply(_make_target_param(clone=create_target_params)),
params.apply(
_make_target_param(clone=create_target_params), filter_empty=False
),
no_convert=True,
)
setattr(self, name_params_target + "_params", target_params)
Expand Down
4 changes: 3 additions & 1 deletion torchrl/objectives/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ def __init__(

actor_critic = ActorCriticWrapper(actor_network, value_network)
params = TensorDict.from_module(actor_critic)
params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
params_meta = params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=False
)
with params_meta.to_module(actor_critic):
self.__dict__["actor_critic"] = deepcopy(actor_critic)

Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(
global_value_network = SafeSequential(local_value_network, mixer_network)
params = TensorDict.from_module(global_value_network)
with params.apply(
self._make_meta_params, device=torch.device("meta")
self._make_meta_params, device=torch.device("meta"), filter_empty=False
).to_module(global_value_network):
self.__dict__["global_value_network"] = deepcopy(global_value_network)

Expand Down

0 comments on commit 1fe745a

Please sign in to comment.