diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py
index 39ea546fd9d..1a5ba7a38c2 100644
--- a/torchrl/objectives/a2c.py
+++ b/torchrl/objectives/a2c.py
@@ -16,7 +16,7 @@
 from tensordict.utils import NestedKey
 from torch import distributions as d
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _cache_values,
@@ -34,20 +34,6 @@
     VTrace,
 )
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class A2CLosses(LossContainerBase):
     """The tensorclass for The A2CLoss Loss class."""
@@ -58,11 +44,6 @@ class A2CLosses(LossContainerBase):
     loss_entropy: torch.Tensor | None = None
     entropy: torch.Tensor | None = None
 
-    @property
-    def aggregate_loss(self):
-        return self.loss_critic + self.loss_objective + self.loss_entropy
-
-
 class A2CLoss(LossModule):
     """TorchRL implementation of the A2C loss.
@@ -164,8 +145,8 @@ class A2CLoss(LossModule):
         A2CLosses(
             entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
             loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_entropy=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_objective=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
             batch_size=torch.Size([]),
             device=None,
             is_shared=False)
@@ -497,7 +478,7 @@ def _cached_detach_critic_network_params(self):
         return self.critic_network_params.detach()
 
     @dispatch()
-    def forward(self, tensordict: TensorDictBase) -> A2CLosses:
+    def forward(self, tensordict: TensorDictBase) -> A2CLosses | TensorDictBase:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
diff --git a/torchrl/objectives/common.py b/torchrl/objectives/common.py
index 6b6fd391560..70b944d3388 100644
--- a/torchrl/objectives/common.py
+++ b/torchrl/objectives/common.py
@@ -12,7 +12,7 @@
 from typing import Iterator, List, Optional, Tuple
 
 import torch
-from tensordict import is_tensor_collection, TensorDict, TensorDictBase
+from tensordict import is_tensor_collection, tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
 from torch import nn
@@ -38,6 +38,20 @@ def __init__(cls, name, bases, attr_dict):
         cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)
 
 
+class LossContainerBase:
+    """Base container class for loss tensorclasses."""
+
+    __getitem__ = TensorDictBase.__getitem__
+
+    @property
+    def aggregate_loss(self):
+        result = torch.zeros((), device=self.device)
+        for key in self.__dataclass_attr__:
+            if key.startswith("loss_"):
+                result += getattr(self, key)
+        return result
+
+
 class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
     """A parent class for RL losses.
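
Reviewer note: the `aggregate_loss` property on the new shared `LossContainerBase` sums every field whose name starts with `loss_`. A minimal sketch of the intended semantics, using a made-up container (the real subclasses are the `*Losses` tensorclasses in the files below):

    import torch
    from tensordict import tensorclass
    from torchrl.objectives.common import LossContainerBase


    @tensorclass
    class ToyLosses(LossContainerBase):
        # hypothetical container; mirrors the shape of the real *Losses classes
        loss_objective: torch.Tensor
        loss_critic: torch.Tensor
        entropy: torch.Tensor  # no "loss_" prefix, so excluded from the sum


    losses = ToyLosses(
        loss_objective=torch.tensor(1.0),
        loss_critic=torch.tensor(0.5),
        entropy=torch.tensor(0.1),
        batch_size=[],
    )
    # aggregate_loss iterates __dataclass_attr__ and accumulates the two loss_* fields
    assert torch.isclose(losses.aggregate_loss, torch.tensor(1.5))
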
@@ -252,7 +265,6 @@ def _compare_and_expand(param):
                 return param._apply_nest(
                     _compare_and_expand,
                     batch_size=[expand_dim, *param.shape],
-                    filter_empty=False,
                     call_on_nested=True,
                 )
             if not isinstance(param, nn.Parameter):
@@ -276,7 +288,6 @@ def _compare_and_expand(param):
             params.apply(
                 _compare_and_expand,
                 batch_size=[expand_dim, *params.shape],
-                filter_empty=False,
                 call_on_nested=True,
             ),
             no_convert=True,
@@ -298,7 +309,7 @@ def _compare_and_expand(param):
         # set the functional module: we need to convert the params to non-differentiable params
         # otherwise they will appear twice in parameters
         with params.apply(
-            self._make_meta_params, device=torch.device("meta"), filter_empty=False
+            self._make_meta_params, device=torch.device("meta")
         ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)
@@ -309,7 +320,7 @@ def _compare_and_expand(param):
         # we create a TensorDictParams to keep the target params as Buffer instances
         target_params = TensorDictParams(
             params.apply(
-                _make_target_param(clone=create_target_params), filter_empty=False
+                _make_target_param(clone=create_target_params)
             ),
             no_convert=True,
         )
diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py
index 4f9ff46d448..6878a4a5a4f 100644
--- a/torchrl/objectives/cql.py
+++ b/torchrl/objectives/cql.py
@@ -25,7 +25,7 @@
 from torchrl.modules import ProbabilisticActor, QValueActor
 from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -37,20 +37,6 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class CQLLosses(LossContainerBase):
     """The tensorclass for The CQLLoss Loss class."""
@@ -217,7 +203,7 @@ class CQLLoss(LossModule):
         >>> loss = CQLLoss(actor, qvalue)
         >>> batch = [2, ]
         >>> action = spec.rand(batch)
-        >>> loss_actor, loss_qvalue, _, _, _, _ = loss(
+        >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, loss_alpha, loss_alpha_prime = loss(
         ...     observation=torch.randn(*batch, n_obs),
         ...     action=action,
         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -532,7 +518,7 @@ def out_keys(self, values):
         self._out_keys = values
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> CQLLosses | TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
             shape = tensordict.shape
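
Note on the docstring fix above: the previous example bound six values but repeated `loss_qvalue`; with `@dispatch`, a keyword-style call returns one tensor per entry of `CQLLoss.out_keys`, in that order, so the unpack list must name each key once. A hedged sketch of the same call pattern (`actor`, `qvalue`, `spec` and `n_obs` as in the docstring):

    loss = CQLLoss(actor, qvalue)
    values = loss(
        observation=torch.randn(2, n_obs),
        action=spec.rand([2]),
        next_done=torch.zeros(2, 1, dtype=torch.bool),
        next_terminated=torch.zeros(2, 1, dtype=torch.bool),
        next_observation=torch.randn(2, n_obs),
        next_reward=torch.randn(2, 1),
    )
    named = dict(zip(loss.out_keys, values))  # e.g. named["loss_qvalue"]
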
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index d73b2d83ef8..65a5c8c02b0 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -15,7 +15,7 @@
 from tensordict.utils import NestedKey, unravel_key
 
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -26,20 +26,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DDPGLosses(LossContainerBase):
     """The tensorclass for The DDPGLoss class."""
@@ -171,7 +157,7 @@ class DDPGLoss(LossModule):
         method.
 
     Examples:
-        >>> out_keys = loss.select_out_keys('loss_actor', 'loss_value')
+        >>> _ = loss.select_out_keys('loss_actor', 'loss_value')
         >>> loss_actor, loss_value = loss(
         ...     observation=torch.randn(n_obs),
         ...     action=spec.rand(),
@@ -315,7 +301,7 @@ def in_keys(self, values):
         self._in_keys = values
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> DDPGLosses:
+    def forward(self, tensordict: TensorDictBase) -> DDPGLosses | TensorDictBase:
         """Computes the DDPG losses given a tensordict sampled from the replay buffer.
 
         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
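
The docstring fix above (binding the result of `select_out_keys` to `_`) reflects that the method narrows the module's outputs in place rather than returning the key list. A hedged sketch of the resulting call, reusing the docstring's `loss`, `spec` and `n_obs`:

    loss.select_out_keys("loss_actor", "loss_value")  # in-place; return value unused
    loss_actor, loss_value = loss(
        observation=torch.randn(n_obs),
        action=spec.rand(),
        next_done=torch.zeros(1, dtype=torch.bool),
        next_terminated=torch.zeros(1, dtype=torch.bool),
        next_observation=torch.randn(n_obs),
        next_reward=torch.randn(1),
    )
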
        This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py
index d93e9fd1087..ca2293f1eb7 100644
--- a/torchrl/objectives/decision_transformer.py
+++ b/torchrl/objectives/decision_transformer.py
@@ -16,23 +16,9 @@
 from torch import distributions as d
 
 from torchrl.modules import ProbabilisticActor
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import distance_loss
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class OnlineDTLosses(LossContainerBase):
     """The tensorclass for The OnlineDTLoss Loss class."""
@@ -226,7 +212,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
         return -log_p.mean(axis=0)
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> OnlineDTLosses | TensorDictBase:
         """Compute the loss for the Online Decision Transformer."""
         # extract action targets
         tensordict = tensordict.clone(False)
diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py
index 2fb06c7de17..78e8662f6f2 100644
--- a/torchrl/objectives/dqn.py
+++ b/torchrl/objectives/dqn.py
@@ -24,7 +24,7 @@
 )
 from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
@@ -34,20 +34,6 @@
 from torchrl.objectives.value import TDLambdaEstimator
 from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DQNLosses(LossContainerBase):
     """The tensorclass for The DQN Loss class."""
@@ -55,10 +41,6 @@ class DQNLosses(LossContainerBase):
     loss_objective: torch.Tensor
     loss: torch.Tensor
 
-    @property
-    def aggregate_loss(self):
-        return self.loss_critic + self.loss_objective + self.loss_entropy
-
 
 class DQNLoss(LossModule):
     """The DQN Loss class.
@@ -334,7 +316,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         self._value_estimator.set_keys(**tensor_keys)
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> DQNLosses:
+    def forward(self, tensordict: TensorDictBase) -> DQNLosses | TensorDictBase:
         """Computes the DQN loss given a tensordict sampled from the replay buffer.
 
         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
@@ -404,7 +386,10 @@ def forward(self, tensordict: TensorDictBase) -> DQNLosses:
             inplace=True,
         )
         loss = distance_loss(pred_val_index, target_value, self.loss_function)
-        return TensorDict({"loss": loss.mean()}, [])
+        loss_td = TensorDict({"loss": loss.mean()}, [])
+        if self.return_tensorclass:
+            return DQNLosses._from_tensordict(loss_td)
+        return loss_td
 
 
 class DistributionalDQNLoss(LossModule):
@@ -531,7 +516,7 @@ def _log_ps_a_categorical(action, action_log_softmax):
         action = action.expand(new_shape)
         return torch.gather(action_log_softmax, -1, index=action).squeeze(-1)
 
-    def forward(self, input_tensordict: TensorDictBase) -> DQNLosses:
+    def forward(self, input_tensordict: TensorDictBase) -> DQNLosses | TensorDictBase:
         # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py
         tensordict = TensorDict(
             source=input_tensordict, batch_size=input_tensordict.batch_size
@@ -644,8 +629,11 @@ def forward(self, input_tensordict: TensorDictBase) -> DQNLosses:
             loss.detach().unsqueeze(1).to(input_tensordict.device),
             inplace=True,
         )
-        loss_td = TensorDict({"loss": loss.mean()}, [])
-        return loss_td
+        loss = _reduce(loss, reduction=self.reduction)
+        td_out = TensorDict({"loss": loss}, [])
+        if self.return_tensorclass:
+            return DQNLosses._from_tensordict(td_out)
+        return td_out
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         if value_type is None:
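
Both DQN forwards now honor `self.return_tensorclass` (and the distributional variant additionally routes its loss through `_reduce`); the flag's registration in `__init__` and the `_reduce`/`reduction` imports are outside these hunks, so the sketch below assumes a `return_tensorclass: bool = False` constructor argument:

    import torch
    from tensordict import TensorDict
    from torchrl.modules import MLP, QValueActor
    from torchrl.objectives import DQNLoss

    n_obs, n_act = 4, 3
    actor = QValueActor(MLP(in_features=n_obs, out_features=n_act), in_keys=["observation"])
    loss_fn = DQNLoss(actor, action_space="one_hot", return_tensorclass=True)  # assumed kwarg

    data = TensorDict(
        {
            "observation": torch.randn(8, n_obs),
            "action": torch.nn.functional.one_hot(torch.randint(n_act, (8,)), n_act).float(),
            ("next", "observation"): torch.randn(8, n_obs),
            ("next", "reward"): torch.randn(8, 1),
            ("next", "done"): torch.zeros(8, 1, dtype=torch.bool),
            ("next", "terminated"): torch.zeros(8, 1, dtype=torch.bool),
        },
        batch_size=[8],
    )
    out = loss_fn(data)  # DQNLosses when the flag is set, plain TensorDict otherwise
    out.loss.backward()  # tensorclass fields are attribute-accessible

One caveat worth checking: `DQNLosses` declares both `loss` and `loss_objective` as required fields, while these forward paths only populate `"loss"`, so `_from_tensordict` may need to tolerate the partial set.
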
diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py
index 3ecffceed71..6379ba7f7d3 100644
--- a/torchrl/objectives/dreamer.py
+++ b/torchrl/objectives/dreamer.py
@@ -14,7 +14,7 @@
 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
@@ -24,20 +24,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DreamerModelLosses(LossContainerBase):
     """The tensorclass for The Dreamer Model Loss class."""
@@ -288,7 +274,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
 
     def forward(
         self, tensordict: TensorDict
-    ) -> Tuple[DreamerModelLosses, DreamerModelLosses]:
+    ) -> Tuple[DreamerModelLosses, TensorDict] | Tuple[TensorDict, TensorDict]:
         with torch.no_grad():
             tensordict = tensordict.select("state", self.tensor_keys.belief)
             tensordict = tensordict.reshape(-1)
@@ -328,7 +314,7 @@ def forward(
         if self.return_tensorclass:
             return DreamerModelLosses._from_tensordict(
                 loss_tensordict
-            ), DreamerModelLosses._from_tensordict(fake_data.detach())
+            ), fake_data.detach()
         return loss_tensordict, fake_data.detach()
 
     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py
index 76b47824eeb..8bfb4ceee73 100644
--- a/torchrl/objectives/iql.py
+++ b/torchrl/objectives/iql.py
@@ -17,7 +17,7 @@
 from torchrl.data.utils import _find_action_space
 from torchrl.modules import ProbabilisticActor
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
     _vmap_func,
@@ -27,20 +27,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class IQLLosses(LossContainerBase):
-    """The tensorclass for The PPOLoss Loss class."""
+    """The tensorclass for the IQLLoss class."""
@@ -377,7 +363,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
         self._set_in_keys()
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> IQLLosses:
+    def forward(self, tensordict: TensorDictBase) -> IQLLosses | TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
             shape = tensordict.shape
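
In the Dreamer hunk above, only the first element of the returned tuple is wrapped now; `fake_data` stays a plain tensordict in both modes, which the updated `Tuple[DreamerModelLosses, TensorDict] | Tuple[TensorDict, TensorDict]` annotation spells out. A hedged consumer-side sketch (`world_model_loss` and `sampled_td` are placeholder names):

    losses, fake_data = world_model_loss(sampled_td)
    if isinstance(losses, DreamerModelLosses):  # return_tensorclass=True path
        total = losses.aggregate_loss
    else:                                       # default TensorDict path
        total = sum(v for k, v in losses.items() if k.startswith("loss_"))
    total.backward()
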
diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py
index 1d363072eb9..9a6a4292a00 100644
--- a/torchrl/objectives/ppo.py
+++ b/torchrl/objectives/ppo.py
@@ -23,7 +23,7 @@
 from tensordict.utils import NestedKey
 from torch import distributions as d
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _cache_values,
@@ -41,25 +41,10 @@
     VTrace,
 )
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class PPOLosses(LossContainerBase):
     """The tensorclass for The PPOLoss Loss class."""
 
-    loss_actor: torch.Tensor
     loss_objective: torch.Tensor
     loss_critic: torch.Tensor | None = None
     loss_entropy: torch.Tensor | None = None
@@ -214,10 +199,10 @@ class PPOLoss(LossModule):
         >>> loss = PPOLoss(actor, value, return_tensorclass=True)
         >>> loss(data)
         PPOLosses(
-            entropy=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+            entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
             loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_entropy=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_objective=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
             batch_size=torch.Size([]),
             device=None,
             is_shared=False)
@@ -570,7 +555,7 @@ def _cached_critic_network_params_detached(self):
         return self.critic_network_params.detach()
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> PPOLosses:
+    def forward(self, tensordict: TensorDictBase) -> PPOLosses | TensorDictBase:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
@@ -595,8 +580,6 @@ def forward(self, tensordict: TensorDictBase) -> PPOLosses:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict).mean()
             td_out.set("loss_critic", loss_critic.mean())
-        if self.return_tensorclass:
-            return PPOLosses._from_tensordict(td_out)
             loss_critic = self.loss_critic(tensordict)
             td_out.set("loss_critic", loss_critic)
         td_out = td_out.named_apply(
@@ -605,6 +588,8 @@ def forward(self, tensordict: TensorDictBase) -> PPOLosses:
             else value,
             batch_size=[],
         )
+        if self.return_tensorclass:
+            return PPOLosses._from_tensordict(td_out)
         return td_out
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
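
Moving the `return_tensorclass` branch below the `named_apply` reduction means `PPOLosses` is built from the already-reduced entries, which is what the corrected scalar shapes in the docstring above show. Hedged call pattern (`actor`, `critic`, `data` and `optim` assumed):

    loss_module = PPOLoss(actor, critic, return_tensorclass=True)
    losses = loss_module(data)        # PPOLosses instead of a TensorDict
    losses.aggregate_loss.backward()  # loss_objective + loss_critic + loss_entropy
    optim.step()
    optim.zero_grad()
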
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 014352060ab..ab9f881e702 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -18,7 +18,7 @@
 from torchrl.data.tensor_specs import CompositeSpec
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _cache_values,
@@ -31,20 +31,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class REDQLosses(LossContainerBase):
     """The tensorclass for The REDQLoss Loss class."""
@@ -476,7 +462,7 @@ def _qvalue_params_cat(self, selected_q_params):
         return qvalue_params
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> REDQLosses:
+    def forward(self, tensordict: TensorDictBase) -> REDQLosses | TensorDictBase:
         obs_keys = self.actor_network.in_keys
         tensordict_select = tensordict.clone(False).select(
             "next", *obs_keys, self.tensor_keys.action
@@ -618,6 +604,14 @@ def forward(self, tensordict: TensorDictBase) -> REDQLosses:
             },
             [],
         )
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
+        )
+        if self.return_tensorclass:
+            return REDQLosses._from_tensordict(td_out)
         return td_out
diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py
index ef793101bf5..aa53998aafa 100644
--- a/torchrl/objectives/reinforce.py
+++ b/torchrl/objectives/reinforce.py
@@ -14,7 +14,7 @@
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.utils import NestedKey
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -31,20 +31,6 @@
     VTrace,
 )
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class ReinforceLosses(LossContainerBase):
     """The tensorclass for The Reinforce Loss class."""
@@ -407,7 +393,7 @@ def in_keys(self, values):
         self._in_keys = values
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> ReinforceLosses:
+    def forward(self, tensordict: TensorDictBase) -> ReinforceLosses | TensorDictBase:
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py
index 1cb50d43db5..4e5712362b3 100644
--- a/torchrl/objectives/sac.py
+++ b/torchrl/objectives/sac.py
@@ -23,7 +23,7 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import ProbabilisticActor
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _cache_values,
@@ -45,20 +45,6 @@ def new_func(self, *args, **kwargs):
         return new_func
 
 
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class SACLosses(LossContainerBase):
     """The tensorclass for The SACLoss Loss class."""
@@ -581,7 +567,7 @@ def out_keys(self, values):
         self._out_keys = values
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> SACLosses:
+    def forward(self, tensordict: TensorDictBase) -> SACLosses | TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
             shape = tensordict.shape
@@ -615,8 +601,17 @@ def forward(self, tensordict: TensorDictBase) -> SACLosses:
             "entropy": entropy.detach().mean(),
         }
         if self._version == 1:
-            out["loss_value"] = loss_value.mean()
-        return TensorDict(out, [])
+            out["loss_value"] = loss_value
+        td_out = TensorDict(out, [])
+        td_out = td_out.named_apply(
+            lambda name, value: _reduce(value, reduction=self.reduction)
+            if name.startswith("loss_")
+            else value,
+            batch_size=[],
+        )
+        if self.return_tensorclass:
+            return SACLosses._from_tensordict(td_out)
+        return td_out
 
     @property
     @_cache_values
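
The SAC and REDQ forwards now push every `loss_*` entry through `_reduce(value, reduction=self.reduction)` before the optional tensorclass wrap. Assuming the PR's `reduction` constructor argument (not shown in these hunks) follows the usual PyTorch convention, the contract would look like:

    # reduction="mean" (assumed default): scalar entries, ready for backward()
    out = SACLoss(actor, qvalue, reduction="mean")(batch_td)
    assert out["loss_actor"].shape == torch.Size([])

    # reduction="none": per-sample entries kept, e.g. for prioritized-replay weighting
    out_none = SACLoss(actor, qvalue, reduction="none")(batch_td)
    assert out_none["loss_actor"].shape == batch_td.shape
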
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index a012f4e4ff8..cae1bd8a51a 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -15,7 +15,7 @@
 from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec
 from torchrl.envs.utils import step_mdp
 
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 
 from torchrl.objectives.utils import (
     _cache_values,
@@ -28,20 +28,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator
 
-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class TD3Losses(LossContainerBase):
     """The tensorclass for The TD3 Loss class."""
@@ -492,7 +478,7 @@ def value_loss(self, tensordict):
         return loss_qval, metadata
 
     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TD3Losses:
+    def forward(self, tensordict: TensorDictBase) -> TD3Losses | TensorDictBase:
         tensordict_save = tensordict
         loss_actor, metadata_actor = self.actor_loss(tensordict)
         loss_qval, metadata_value = self.value_loss(tensordict_save)
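
Since every `forward` is now annotated as `XLosses | TensorDictBase`, downstream code that must handle both return types can normalize them with a small helper (a generic sketch, not part of this PR):

    import torch
    from tensordict import TensorDictBase

    def total_loss(out) -> torch.Tensor:
        """Sum loss_* entries from either a TensorDict or a LossContainerBase tensorclass."""
        if isinstance(out, TensorDictBase):
            return sum(v for k, v in out.items() if k.startswith("loss_"))
        return out.aggregate_loss

    total_loss(loss_fn(batch)).backward()
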