From 4c2d6c2ee0b070847476f0ed9f511193d1dd9d4e Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 9 Feb 2024 21:01:16 +0000 Subject: [PATCH 1/2] init --- torchrl/objectives/a2c.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6edcda5c800..edd6c655ac4 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -6,10 +6,10 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Tuple +from typing import Tuple, overload import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import TensorDict, TensorDictBase, unravel_key from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -351,7 +351,12 @@ def in_keys(self): ] if self.critic_coef: keys.extend(self.critic.in_keys) - return list(set(keys)) + out_keys = [] + for key in keys: + key = unravel_key(key) + if key not in keys: + out_keys.append(key) + return out_keys @property def out_keys(self): @@ -443,6 +448,12 @@ def _cached_detach_critic_network_params(self): return None return self.critic_network_params.detach() + + @overload + def forward(self, *, action, next_reward, next_terminated, next_truncated, next_observation, observation): + # The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py + ... + @dispatch() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: tensordict = tensordict.clone(False) From 738d97109190cf04c8b5c450d13b622f718a7af7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 9 Feb 2024 21:05:24 +0000 Subject: [PATCH 2/2] lint --- torchrl/objectives/a2c.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index edd6c655ac4..95084d20376 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -6,7 +6,7 @@ import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Tuple, overload +from typing import overload, Tuple import torch from tensordict import TensorDict, TensorDictBase, unravel_key @@ -448,9 +448,17 @@ def _cached_detach_critic_network_params(self): return None return self.critic_network_params.detach() - @overload - def forward(self, *, action, next_reward, next_terminated, next_truncated, next_observation, observation): + def forward( + self, + *, + action, + next_reward, + next_terminated, + next_truncated, + next_observation, + observation, + ): # The key names can be extrapolated from test_a2c_notensordict in test/test_cost.py ...