Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NOMERG] Add @overload to forward in losses #1893

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
from typing import overload, Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import TensorDict, TensorDictBase, unravel_key
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
Expand Down Expand Up @@ -351,7 +351,12 @@ def in_keys(self):
]
if self.critic_coef:
keys.extend(self.critic.in_keys)
return list(set(keys))
out_keys = []
for key in keys:
key = unravel_key(key)
if key not in keys:
out_keys.append(key)
return out_keys

@property
def out_keys(self):
Expand Down Expand Up @@ -443,6 +448,20 @@ def _cached_detach_critic_network_params(self):
return None
return self.critic_network_params.detach()

@overload
def forward(
self,
*,
action,
next_reward,
next_terminated,
next_truncated,
next_observation,
observation,
):
# The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py
...

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
Expand Down
Loading