Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 9, 2024
1 parent 4c2d6c2 commit 738d971
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple, overload
from typing import overload, Tuple

import torch
from tensordict import TensorDict, TensorDictBase, unravel_key
Expand Down Expand Up @@ -448,9 +448,17 @@ def _cached_detach_critic_network_params(self):
return None
return self.critic_network_params.detach()


@overload
def forward(self, *, action, next_reward, next_terminated, next_truncated, next_observation, observation):
def forward(
self,
*,
action,
next_reward,
next_terminated,
next_truncated,
next_observation,
observation,
):
# The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py
...

Expand Down

0 comments on commit 738d971

Please sign in to comment.