[Refactor] Use filter_empty=False in apply for params #1882

Merged (3 commits) on Feb 6, 2024
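
All of the hunks below make the same change: every apply call that produces module parameters now passes filter_empty=False, so sub-tensordicts that end up with no tensor leaves (for instance, parameter-free submodules) are kept in the output and the result retains the structure of the module it was extracted from. A minimal sketch of the difference, assuming a tensordict release in which apply exposes filter_empty (the keys and shapes below are made up for illustration):

import torch
from tensordict import TensorDict

params = TensorDict(
    {
        "linear": {"weight": torch.randn(3, 3), "bias": torch.randn(3)},
        # A parameter-free submodule shows up as an empty sub-tensordict.
        "activation": TensorDict({}, batch_size=[]),
    },
    batch_size=[],
)

# filter_empty=False keeps the empty "activation" node, so the output still
# mirrors the module hierarchy the tensordict was extracted from.
kept = params.apply(lambda t: t.detach(), filter_empty=False)
assert "activation" in kept.keys()

# filter_empty=True drops empty sub-tensordicts from the result.
dropped = params.apply(lambda t: t.detach(), filter_empty=True)
assert "activation" not in dropped.keys()
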
10 changes: 7 additions & 3 deletions torchrl/collectors/collectors.py
@@ -316,12 +316,14 @@ def _map_to_device_params(param, device):

        # Create a stateless policy, then populate this copy with params on device
        with param_and_buf.apply(
-            functools.partial(_map_to_device_params, device="meta")
+            functools.partial(_map_to_device_params, device="meta"),
+            filter_empty=False,
        ).to_module(policy):
            policy = deepcopy(policy)

        param_and_buf.apply(
-            functools.partial(_map_to_device_params, device=self.policy_device)
+            functools.partial(_map_to_device_params, device=self.policy_device),
+            filter_empty=False,
        ).to_module(policy)

        return policy, get_weights_fn
@@ -1495,7 +1497,9 @@ def map_weight(
                weight = nn.Parameter(weight, requires_grad=False)
            return weight

-        local_policy_weights = TensorDictParams(policy_weights.apply(map_weight))
+        local_policy_weights = TensorDictParams(
+            policy_weights.apply(map_weight, filter_empty=False)
+        )

        def _get_weight_fn(weights=policy_weights):
            # This function will give the local_policy_weight the original weights.
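
The collectors change above is also the pattern the rest of this PR repeats: parameters are temporarily swapped for meta-device placeholders so that deepcopy yields a stateless copy of the module, and the real weights are written back with to_module afterwards. A rough, self-contained sketch of that pattern (the _to_meta helper below is hypothetical, standing in for _map_to_device_params):

import copy

from torch import nn
from tensordict import TensorDict


def _to_meta(p):
    # Hypothetical stand-in for _map_to_device_params(param, device="meta"):
    # keep requires_grad but move the storage to the meta device.
    return nn.Parameter(p.data.to("meta"), requires_grad=p.requires_grad)


policy = nn.Linear(8, 2)
param_and_buf = TensorDict.from_module(policy)

# Swap in meta-device placeholders, deepcopy a stateless skeleton, and restore
# the original parameters when the context manager exits.
with param_and_buf.apply(_to_meta, filter_empty=False).to_module(policy):
    stateless_policy = copy.deepcopy(policy)

# Populate the stateless copy with the real weights.
param_and_buf.apply(
    lambda p: nn.Parameter(p.data, requires_grad=p.requires_grad),
    filter_empty=False,
).to_module(stateless_policy)
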
6 changes: 4 additions & 2 deletions torchrl/envs/transforms/rlhf.py
@@ -112,7 +112,9 @@ def __init__(

        # check that the model has parameters
        params = TensorDict.from_module(actor)
-        with params.apply(_stateless_param, device="meta").to_module(actor):
+        with params.apply(
+            _stateless_param, device="meta", filter_empty=False
+        ).to_module(actor):
            # copy a stateless actor
            self.__dict__["functional_actor"] = deepcopy(actor)
            # we need to register these params as buffer to have `to` and similar
@@ -129,7 +131,7 @@ def _make_detached_param(x):
                )
            return x.clone()

-        self.frozen_params = params.apply(_make_detached_param)
+        self.frozen_params = params.apply(_make_detached_param, filter_empty=False)
        if requires_grad:
            # includes the frozen params/buffers in the module parameters/buffers
            self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True)
10 changes: 7 additions & 3 deletions torchrl/objectives/common.py
@@ -262,7 +262,9 @@ def _compare_and_expand(param):

        params = TensorDictParams(
            params.apply(
-                _compare_and_expand, batch_size=[expand_dim, *params.shape]
+                _compare_and_expand,
+                batch_size=[expand_dim, *params.shape],
+                filter_empty=False,
            ),
            no_convert=True,
        )
@@ -283,7 +285,7 @@ def _compare_and_expand(param):
        # set the functional module: we need to convert the params to non-differentiable params
        # otherwise they will appear twice in parameters
        with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
        ).to_module(module):
            # avoid buffers and params being exposed
            self.__dict__[module_name] = deepcopy(module)
@@ -293,7 +295,9 @@ def _compare_and_expand(param):
        # if create_target_params:
        # we create a TensorDictParams to keep the target params as Buffer instances
        target_params = TensorDictParams(
-            params.apply(_make_target_param(clone=create_target_params)),
+            params.apply(
+                _make_target_param(clone=create_target_params), filter_empty=False
+            ),
            no_convert=True,
        )
        setattr(self, name_params_target + "_params", target_params)
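
The last hunk above is the target-network parameter setup: an apply with filter_empty=False produces the detached copy that is then stored as buffers. An illustrative sketch of that idea, where _make_target_param is a simplified stand-in rather than the torchrl implementation:

from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictParams


def _make_target_param(clone: bool):
    def _make(p):
        # Detach from autograd; optionally clone so the target owns its storage.
        return p.data.clone() if clone else p.data

    return _make


module = nn.Linear(4, 2)
params = TensorDict.from_module(module)

# filter_empty=False keeps the tensordict structure aligned with the module,
# and no_convert=True stores the targets as buffers rather than parameters.
target_params = TensorDictParams(
    params.apply(_make_target_param(clone=True), filter_empty=False),
    no_convert=True,
)
assert not target_params["weight"].requires_grad
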
4 changes: 3 additions & 1 deletion torchrl/objectives/ddpg.py
@@ -197,7 +197,9 @@ def __init__(

        actor_critic = ActorCriticWrapper(actor_network, value_network)
        params = TensorDict.from_module(actor_critic)
-        params_meta = params.apply(self._make_meta_params, device=torch.device("meta"))
+        params_meta = params.apply(
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
+        )
        with params_meta.to_module(actor_critic):
            self.__dict__["actor_critic"] = deepcopy(actor_critic)

2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
@@ -224,7 +224,7 @@ def __init__(
        global_value_network = SafeSequential(local_value_network, mixer_network)
        params = TensorDict.from_module(global_value_network)
        with params.apply(
-            self._make_meta_params, device=torch.device("meta")
+            self._make_meta_params, device=torch.device("meta"), filter_empty=False
        ).to_module(global_value_network):
            self.__dict__["global_value_network"] = deepcopy(global_value_network)
