Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 9, 2024
1 parent 4d52d5f commit 4c2d6c2
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple
from typing import Tuple, overload

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import TensorDict, TensorDictBase, unravel_key
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
Expand Down Expand Up @@ -351,7 +351,12 @@ def in_keys(self):
]
if self.critic_coef:
keys.extend(self.critic.in_keys)
return list(set(keys))
out_keys = []
for key in keys:
key = unravel_key(key)
if key not in keys:
out_keys.append(key)
return out_keys

@property
def out_keys(self):
Expand Down Expand Up @@ -443,6 +448,12 @@ def _cached_detach_critic_network_params(self):
return None
return self.critic_network_params.detach()


@overload
def forward(self, *, action, next_reward, next_terminated, next_truncated, next_observation, observation):
# The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py
...

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
tensordict = tensordict.clone(False)
Expand Down

0 comments on commit 4c2d6c2

Please sign in to comment.