diff --git a/test/conftest.py b/test/conftest.py
index 5ce980a4080..41bd69ec2f6 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -53,7 +53,7 @@ def fin():
     request.addfinalizer(fin)
 
 
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture(autouse=True)
 def set_warnings() -> None:
     warnings.filterwarnings(
         "ignore",
diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 6d291bf5986..8fcbd5a6699 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -255,7 +255,8 @@ def __init__(
 
         if functional:
             self.convert_to_functional(
-                actor_network, "actor_network",
+                actor_network,
+                "actor_network",
             )
         else:
             self.actor_network = actor_network
@@ -350,7 +351,7 @@ def in_keys(self):
             *[("next", key) for key in self.actor_network.in_keys],
         ]
         if self.critic_coef:
-            keys.extend(self.critic.in_keys)
+            keys.extend(self.critic_network.in_keys)
         return list(set(keys))
 
     @property
@@ -414,11 +415,11 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         # TODO: if the advantage is gathered by forward, this introduces an
         # overhead that we could easily reduce.
         target_return = tensordict.get(self.tensor_keys.value_target)
-        tensordict_select = tensordict.select(*self.critic.in_keys)
+        tensordict_select = tensordict.select(*self.critic_network.in_keys)
         with self.critic_network_params.to_module(
-            self.critic
+            self.critic_network
         ) if self.functional else contextlib.nullcontext():
-            state_value = self.critic(
+            state_value = self.critic_network(
                 tensordict_select,
             ).get(self.tensor_keys.value)
             loss_value = distance_loss(
@@ -477,13 +478,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         if hasattr(self, "gamma"):
             hp["gamma"] = self.gamma
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD1Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD0Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(value_network=self.critic, **hp)
+            self._value_estimator = GAE(value_network=self.critic_network, **hp)
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+            self._value_estimator = TDLambdaEstimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.VTrace:
             # VTrace currently does not support functional call on the actor
             if self.functional:
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index 4613810d0d3..9738b922c5d 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -351,7 +351,7 @@ def _set_in_keys(self):
             ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
-            *self.critic.in_keys,
+            *self.critic_network.in_keys,
         ]
         self._in_keys = list(set(keys))
 
@@ -398,11 +398,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         try:
             target_return = tensordict.get(self.tensor_keys.value_target)
-            tensordict_select = tensordict.select(*self.critic.in_keys)
+            tensordict_select = tensordict.select(*self.critic_network.in_keys)
             with self.critic_network_params.to_module(
-                self.critic
+                self.critic_network
             ) if self.functional else contextlib.nullcontext():
-                state_value = self.critic(tensordict_select).get(self.tensor_keys.value)
+                state_value = self.critic_network(tensordict_select).get(
+                    self.tensor_keys.value
+                )
             loss_value = distance_loss(
                 target_return,
                 state_value,
@@ -427,13 +429,19 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             hp["gamma"] = self.gamma
         hp.update(hyperparams)
         if value_type == ValueEstimators.TD1:
-            self._value_estimator = TD1Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD1Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.TD0:
-            self._value_estimator = TD0Estimator(value_network=self.critic, **hp)
+            self._value_estimator = TD0Estimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.GAE:
-            self._value_estimator = GAE(value_network=self.critic, **hp)
+            self._value_estimator = GAE(value_network=self.critic_network, **hp)
         elif value_type == ValueEstimators.TDLambda:
-            self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp)
+            self._value_estimator = TDLambdaEstimator(
+                value_network=self.critic_network, **hp
+            )
         elif value_type == ValueEstimators.VTrace:
             # VTrace currently does not support functional call on the actor
             if self.functional: