diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 202dcc9ead8..98040d9640e 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -316,12 +316,14 @@ def _map_to_device_params(param, device): # Create a stateless policy, then populate this copy with params on device with param_and_buf.apply( - functools.partial(_map_to_device_params, device="meta") + functools.partial(_map_to_device_params, device="meta"), + filter_empty=False, ).to_module(policy): policy = deepcopy(policy) param_and_buf.apply( - functools.partial(_map_to_device_params, device=self.policy_device) + functools.partial(_map_to_device_params, device=self.policy_device), + filter_empty=False, ).to_module(policy) return policy, get_weights_fn @@ -1495,7 +1497,9 @@ def map_weight( weight = nn.Parameter(weight, requires_grad=False) return weight - local_policy_weights = TensorDictParams(policy_weights.apply(map_weight)) + local_policy_weights = TensorDictParams( + policy_weights.apply(map_weight, filter_empty=False) + ) def _get_weight_fn(weights=policy_weights): # This function will give the local_policy_weight the original weights. diff --git a/torchrl/envs/transforms/rlhf.py b/torchrl/envs/transforms/rlhf.py index 623bc2864fe..33874393038 100644 --- a/torchrl/envs/transforms/rlhf.py +++ b/torchrl/envs/transforms/rlhf.py @@ -112,7 +112,9 @@ def __init__( # check that the model has parameters params = TensorDict.from_module(actor) - with params.apply(_stateless_param, device="meta").to_module(actor): + with params.apply( + _stateless_param, device="meta", filter_empty=False + ).to_module(actor): # copy a stateless actor self.__dict__["functional_actor"] = deepcopy(actor) # we need to register these params as buffer to have `to` and similar @@ -129,7 +131,7 @@ def _make_detached_param(x): ) return x.clone() - self.frozen_params = params.apply(_make_detached_param) + self.frozen_params = params.apply(_make_detached_param, filter_empty=False) if requires_grad: # includes the frozen params/buffers in the module parameters/buffers self.frozen_params = TensorDictParams(self.frozen_params, no_convert=True) diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py index 1f5edcf26ed..5d620b56227 100644 --- a/torchrl/objectives/common.py +++ b/torchrl/objectives/common.py @@ -262,7 +262,9 @@ def _compare_and_expand(param): params = TensorDictParams( params.apply( - _compare_and_expand, batch_size=[expand_dim, *params.shape] + _compare_and_expand, + batch_size=[expand_dim, *params.shape], + filter_empty=False, ), no_convert=True, ) @@ -283,7 +285,7 @@ def _compare_and_expand(param): # set the functional module: we need to convert the params to non-differentiable params # otherwise they will appear twice in parameters with params.apply( - self._make_meta_params, device=torch.device("meta") + self._make_meta_params, device=torch.device("meta"), filter_empty=False ).to_module(module): # avoid buffers and params being exposed self.__dict__[module_name] = deepcopy(module) @@ -293,7 +295,9 @@ def _compare_and_expand(param): # if create_target_params: # we create a TensorDictParams to keep the target params as Buffer instances target_params = TensorDictParams( - params.apply(_make_target_param(clone=create_target_params)), + params.apply( + _make_target_param(clone=create_target_params), filter_empty=False + ), no_convert=True, ) setattr(self, name_params_target + "_params", target_params) diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 6572084c8ec..03e82689ad5 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -197,7 +197,9 @@ def __init__( actor_critic = ActorCriticWrapper(actor_network, value_network) params = TensorDict.from_module(actor_critic) - params_meta = params.apply(self._make_meta_params, device=torch.device("meta")) + params_meta = params.apply( + self._make_meta_params, device=torch.device("meta"), filter_empty=False + ) with params_meta.to_module(actor_critic): self.__dict__["actor_critic"] = deepcopy(actor_critic) diff --git a/torchrl/objectives/multiagent/qmixer.py b/torchrl/objectives/multiagent/qmixer.py index f7b9307a962..fcfcba49ca1 100644 --- a/torchrl/objectives/multiagent/qmixer.py +++ b/torchrl/objectives/multiagent/qmixer.py @@ -224,7 +224,7 @@ def __init__( global_value_network = SafeSequential(local_value_network, mixer_network) params = TensorDict.from_module(global_value_network) with params.apply( - self._make_meta_params, device=torch.device("meta") + self._make_meta_params, device=torch.device("meta"), filter_empty=False ).to_module(global_value_network): self.__dict__["global_value_network"] = deepcopy(global_value_network)