diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index edd6c655ac4..95084d20376 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -6,7 +6,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Tuple, overload +from typing import overload, Tuple import torch from tensordict import TensorDict, TensorDictBase, unravel_key @@ -448,9 +448,17 @@ def _cached_detach_critic_network_params(self): return None return self.critic_network_params.detach() - @overload - def forward(self, *, action, next_reward, next_terminated, next_truncated, next_observation, observation): + def forward( + self, + *, + action, + next_reward, + next_terminated, + next_truncated, + next_observation, + observation, + ): # The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py ...