diff --git a/test/test_cost.py b/test/test_cost.py index 3c07f5f79f4..5db7c688b7b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -491,7 +491,11 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est): action_spec_type=action_spec_type, device=device ) loss_fn = DQNLoss( - actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn + actor, + loss_function="l2", + delay_value=delay_value, + double_dqn=double_dqn, + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -1490,6 +1494,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est): loss_function="l2", delay_actor=delay_actor, delay_value=delay_value, + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -2118,6 +2123,7 @@ def test_td3( noise_clip=noise_clip, delay_actor=delay_actor, delay_qvalue=delay_qvalue, + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -4216,6 +4222,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est): num_qvalue_nets=num_qvalue, loss_function="l2", delay_qvalue=delay_qvalue, + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): @@ -5013,6 +5020,7 @@ def test_cql( with_lagrange=with_lagrange, delay_actor=delay_actor, delay_qvalue=delay_qvalue, + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): @@ -6648,7 +6656,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional): else: raise NotImplementedError - loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional) + loss_fn = A2CLoss( + actor, + value, + loss_critic_type="l2", + functional=functional, + return_tensorclass=False, + ) # Check error is raised when actions require grads td["action"].requires_grad = True @@ -7113,6 +7127,7 @@ def test_reinforce_value_net( critic_network=value_net, delay_value=delay_value, functional=functional, + return_tensorclass=False, ) td = TensorDict( @@ -7705,6 +7720,7 @@ def test_dreamer_world_model( reco_loss=reco_loss, delayed_clamp=delayed_clamp, free_nats=free_nats, + return_tensorclass=False, ) loss_td, _ = loss_module(tensordict) for loss_str, lmbda in zip( @@ -8525,6 +8541,7 @@ def test_iql( temperature=temperature, expectile=expectile, loss_function="l2", + return_tensorclass=False, ) if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): with pytest.raises(NotImplementedError): diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 8fcbd5a6699..dc96cd8f56d 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import contextlib import warnings from copy import deepcopy @@ -9,7 +11,7 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -31,6 +33,20 @@ ) +@tensorclass +class A2CLosses: + """The tensorclass for The A2CLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class A2CLoss(LossModule): """TorchRL implementation of the A2C loss. @@ -234,6 +250,7 @@ def __init__( functional: bool = True, actor: ProbabilisticTensorDictSequential = None, critic: ProbabilisticTensorDictSequential = None, + return_tensorclass: bool = False, ): if actor is not None: actor_network = actor @@ -290,6 +307,7 @@ def __init__( if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self.loss_critic_type = loss_critic_type + self.return_tensorclass = return_tensorclass @property def functional(self): @@ -445,7 +463,7 @@ def _cached_detach_critic_network_params(self): return self.critic_network_params.detach() @dispatch() - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> A2CLosses: tensordict = tensordict.clone(False) advantage = tensordict.get(self.tensor_keys.advantage, None) if advantage is None: @@ -466,6 +484,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.critic_coef: loss_critic = self.loss_critic(tensordict).mean() td_out.set("loss_critic", loss_critic.mean()) + if self.return_tensorclass: + return A2CLosses._from_tensordict(td_out) return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): diff --git a/torchrl/objectives/cql.py b/torchrl/objectives/cql.py index 69a30c7f484..d6d7dc9e116 100644 --- a/torchrl/objectives/cql.py +++ b/torchrl/objectives/cql.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import math import warnings from copy import deepcopy @@ -12,7 +14,7 @@ import numpy as np import torch import torch.nn as nn -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey, unravel_key from torch import Tensor @@ -36,6 +38,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class CQLLosses: + """The tensorclass for The CQLLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class CQLLoss(LossModule): """TorchRL implementation of the continuous CQL loss. 
@@ -269,6 +285,7 @@ def __init__( num_random: int = 10, with_lagrange: bool = False, lagrange_thresh: float = 0.0, + return_tensorclass: bool = False, ) -> None: self._out_keys = None super().__init__() @@ -354,6 +371,7 @@ def __init__( self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) + self.return_tensorclass = return_tensorclass @property def target_entropy(self): @@ -521,7 +539,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: } if self.with_lagrange: out["loss_alpha_prime"] = alpha_prime_loss.mean() - return TensorDict(out, []) + td_out = TensorDict(out, []) + if self.return_tensorclass: + return CQLLosses._from_tensordict(td_out) + return td_out @property @_cache_values @@ -1000,6 +1021,7 @@ def __init__( delay_value: bool = True, gamma: float = None, action_space=None, + return_tensorclass: bool = False, ) -> None: super().__init__() self._in_keys = None @@ -1040,6 +1062,7 @@ def __init__( if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self.return_tensorclass = return_tensorclass def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -1171,7 +1194,7 @@ def value_loss( return loss, metadata @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDict: + def forward(self, tensordict: TensorDictBase) -> CQLLosses: """Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer. This function will also write a "td_error" key that can be used by prioritized replay buffers to assign @@ -1196,6 +1219,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: source=source, batch_size=[], ) + if self.return_tensorclass: + return CQLLosses._from_tensordict(td_out) return td_out diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py index 03e82689ad5..4e12ab57585 100644 --- a/torchrl/objectives/ddpg.py +++ b/torchrl/objectives/ddpg.py @@ -10,7 +10,7 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey, unravel_key @@ -26,6 +26,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class DDPGLosses: + """The tensorclass for The DDPGLoss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class DDPGLoss(LossModule): """The DDPG Loss class. @@ -189,6 +203,7 @@ def __init__( delay_value: bool = True, gamma: float = None, separate_losses: bool = False, + return_tensorclass: bool = False, ) -> None: self._in_keys = None super().__init__() @@ -229,6 +244,7 @@ def __init__( ) self.loss_function = loss_function + self.return_tensorclass = return_tensorclass if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @@ -266,7 +282,7 @@ def in_keys(self, values): self._in_keys = values @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDict: + def forward(self, tensordict: TensorDictBase) -> DDPGLosses: """Computes the DDPG losses given a tensordict sampled from the replay buffer. 
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign @@ -283,10 +299,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: loss_value, metadata = self.loss_value(tensordict) loss_actor, metadata_actor = self.loss_actor(tensordict) metadata.update(metadata_actor) - return TensorDict( + td_out = TensorDict( source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, batch_size=[], ) + if self.return_tensorclass: + return DDPGLosses._from_tensordict(td_out) + return td_out def loss_actor( self, diff --git a/torchrl/objectives/decision_transformer.py b/torchrl/objectives/decision_transformer.py index 954bd0b9a42..8b95efd2e44 100644 --- a/torchrl/objectives/decision_transformer.py +++ b/torchrl/objectives/decision_transformer.py @@ -2,13 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations import math from dataclasses import dataclass from typing import Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch from tensordict.utils import NestedKey @@ -19,6 +20,20 @@ from torchrl.objectives.utils import distance_loss +@tensorclass +class OnlineDTLosses: + """The tensorclass for The OnlineDTLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class OnlineDTLoss(LossModule): r"""TorchRL implementation of the Online Decision Transformer loss. 
@@ -78,6 +93,7 @@ def __init__( fixed_alpha: bool = False, target_entropy: Union[str, float] = "auto", samples_mc_entropy: int = 1, + return_tensorclass: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -146,6 +162,7 @@ def __init__( ) self.samples_mc_entropy = samples_mc_entropy + self.return_tensorclass = return_tensorclass self._set_in_keys() def _set_in_keys(self): @@ -223,7 +240,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "entropy": entropy.detach(), "alpha": self.alpha.detach(), } - return TensorDict(out, []) + td_out = TensorDict(out, []) + if self.return_tensorclass: + return OnlineDTLosses._from_tensordict(td_out) + return td_out class DTLoss(LossModule): @@ -265,6 +285,7 @@ def __init__( actor_network: ProbabilisticActor, *, loss_function: str = "l2", + return_tensorclass: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -277,6 +298,7 @@ def __init__( create_target_params=False, ) self.loss_function = loss_function + self.return_tensorclass = return_tensorclass def _set_in_keys(self): keys = self.actor_network.in_keys @@ -310,7 +332,7 @@ def out_keys(self, values): self._out_keys = values @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> OnlineDTLosses: """Compute the loss for the Online Decision Transformer.""" # extract action targets tensordict = tensordict.clone(False) @@ -328,4 +350,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: out = { "loss": loss, } - return TensorDict(out, []) + td_out = TensorDict(out, []) + if self.return_tensorclass: + return OnlineDTLosses._from_tensordict(td_out) + return td_out diff --git a/torchrl/objectives/dqn.py b/torchrl/objectives/dqn.py index 2298c262368..0f25dcbabad 100644 --- a/torchrl/objectives/dqn.py +++ b/torchrl/objectives/dqn.py @@ -2,12 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import warnings from dataclasses import dataclass from typing import Optional, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch from tensordict.utils import NestedKey from torch import nn @@ -33,6 +35,20 @@ from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator +@tensorclass +class DQNLosses: + """The tensorclass for The DQN Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class DQNLoss(LossModule): """The DQN Loss class. 
@@ -171,6 +187,7 @@ def __init__( gamma: float = None, action_space: Union[str, TensorSpec] = None, priority_key: str = None, + return_tensorclass: bool = False, ) -> None: if delay_value is None: warnings.warn( @@ -225,6 +242,7 @@ def __init__( ) action_space = "one-hot" self.action_space = _find_action_space(action_space) + self.return_tensorclass = return_tensorclass if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) @@ -292,7 +310,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator.set_keys(**tensor_keys) @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDict: + def forward(self, tensordict: TensorDictBase) -> DQNLosses: """Computes the DQN loss given a tensordict sampled from the replay buffer. This function will also write a "td_error" key that can be used by prioritized replay buffers to assign @@ -362,7 +380,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: inplace=True, ) loss = distance_loss(pred_val_index, target_value, self.loss_function) - return TensorDict({"loss": loss.mean()}, []) + td_out = TensorDict({"loss": loss.mean()}, []) + if self.return_tensorclass: + return DQNLosses._from_tensordict(td_out) + return td_out class DistributionalDQNLoss(LossModule): @@ -435,6 +456,7 @@ def __init__( gamma: float, delay_value: bool = None, priority_key: str = None, + return_tensorclass: bool = False, ): if delay_value is None: warnings.warn( @@ -461,6 +483,7 @@ def __init__( create_target_params=self.delay_value, ) self.action_space = self.value_network.action_space + self.return_tensorclass = return_tensorclass def _forward_value_estimator_keys(self, **kwargs) -> None: pass @@ -483,7 +506,7 @@ def _log_ps_a_categorical(action, action_log_softmax): action = action.expand(new_shape) return torch.gather(action_log_softmax, -1, index=action).squeeze(-1) - def forward(self, input_tensordict: TensorDictBase) -> TensorDict: + def forward(self, input_tensordict: TensorDictBase) -> DQNLosses: # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py tensordict = TensorDict( source=input_tensordict, batch_size=input_tensordict.batch_size @@ -597,6 +620,8 @@ def forward(self, input_tensordict: TensorDictBase) -> TensorDict: inplace=True, ) loss_td = TensorDict({"loss": loss.mean()}, []) + if self.return_tensorclass: + return DQNLosses._from_tensordict(loss_td) return loss_td def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): diff --git a/torchrl/objectives/dreamer.py b/torchrl/objectives/dreamer.py index 9fd8a8a0bd2..1d5f21b28f2 100644 --- a/torchrl/objectives/dreamer.py +++ b/torchrl/objectives/dreamer.py @@ -2,11 +2,13 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + from dataclasses import dataclass from typing import Optional, Tuple import torch -from tensordict import TensorDict +from tensordict import tensorclass, TensorDict from tensordict.nn import TensorDictModule from tensordict.utils import NestedKey @@ -23,6 +25,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class DreamerModelLosses: + """The tensorclass for The Dreamer Model Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class DreamerModelLoss(LossModule): """Dreamer Model Loss. @@ -99,6 +115,7 @@ def __init__( free_nats: int = 3, delayed_clamp: bool = False, global_average: bool = False, + return_tensorclass: bool = False, ): super().__init__() self.world_model = world_model @@ -110,6 +127,7 @@ def __init__( self.free_nats = free_nats self.delayed_clamp = delayed_clamp self.global_average = global_average + self.return_tensorclass = return_tensorclass def _forward_value_estimator_keys(self, **kwargs) -> None: pass @@ -238,6 +256,7 @@ def __init__( discount_loss: bool = False, # for consistency with paper gamma: int = None, lmbda: int = None, + return_tensorclass: bool = False, ): super().__init__() self.actor_model = actor_model @@ -249,6 +268,7 @@ def __init__( raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) if lmbda is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self.return_tensorclass = return_tensorclass def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -256,7 +276,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: value=self._tensor_keys.value, ) - def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: + def forward( + self, tensordict: TensorDict + ) -> Tuple[DreamerModelLosses, DreamerModelLosses]: with torch.no_grad(): tensordict = tensordict.select("state", self.tensor_keys.belief) tensordict = tensordict.reshape(-1) @@ -293,6 +315,10 @@ def forward(self, tensordict: TensorDict) -> Tuple[TensorDict, TensorDict]: else: actor_loss = -lambda_target.sum((-2, -1)).mean() loss_tensordict = TensorDict({"loss_actor": actor_loss}, []) + if self.return_tensorclass: + return DreamerModelLosses._from_tensordict( + loss_tensordict + ), DreamerModelLosses._from_tensordict(fake_data.detach()) return loss_tensordict, fake_data.detach() def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor: diff --git a/torchrl/objectives/iql.py b/torchrl/objectives/iql.py index 62d2a628af4..47eb76a791a 100644 --- a/torchrl/objectives/iql.py +++ b/torchrl/objectives/iql.py @@ -2,12 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import warnings from dataclasses import dataclass from typing import Optional, Tuple, Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torch import Tensor @@ -26,6 +28,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class IQLLosses: + """The tensorclass for The IQLLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class IQLLoss(LossModule): r"""TorchRL implementation of the IQL loss. @@ -236,6 +252,7 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + return_tensorclass: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -289,6 +306,7 @@ def __init__( self._vmap_qvalue_networkN0 = _vmap_func( self.qvalue_network, (None, 0), randomness=self.vmap_randomness ) + self.return_tensorclass = return_tensorclass @property def device(self) -> torch.device: @@ -336,7 +354,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None: self._set_in_keys() @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> IQLLosses: shape = None if tensordict.ndimension() > 1: shape = tensordict.shape @@ -368,11 +386,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: "loss_value": loss_value.mean(), "entropy": entropy.mean(), } - - return TensorDict( + td_out = TensorDict( out, [], ) + if self.return_tensorclass: + return IQLLosses._from_tensordict(td_out) + return td_out def actor_loss(self, tensordict: TensorDictBase) -> Tensor: # KL loss diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index ac2244b9a23..f80d0ca0848 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -13,7 +13,7 @@ from typing import Tuple import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey from torch import distributions as d @@ -30,6 +30,20 @@ from .value import GAE, TD0Estimator, TD1Estimator, TDLambdaEstimator, VTrace +@tensorclass +class PPOLosses: + """The tensorclass for The PPOLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class PPOLoss(LossModule): """A parent PPO loss class.
@@ -278,6 +292,7 @@ def __init__( functional: bool = True, actor: ProbabilisticTensorDictSequential = None, critic: ProbabilisticTensorDictSequential = None, + return_tensorclass: bool = False, ): if actor is not None: actor_network = actor @@ -336,6 +351,7 @@ def __init__( value_target=value_target_key, value=value_key, ) + self.return_tensorclass = return_tensorclass @property def functional(self): @@ -514,7 +530,7 @@ def _cached_critic_network_params_detached(self): return self.critic_network_params.detach() @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> PPOLosses: tensordict = tensordict.clone(False) advantage = tensordict.get(self.tensor_keys.advantage, None) if advantage is None: @@ -539,6 +555,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if self.critic_coef: loss_critic = self.loss_critic(tensordict).mean() td_out.set("loss_critic", loss_critic.mean()) + if self.return_tensorclass: + return PPOLosses._from_tensordict(td_out) return td_out def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py index 61aaf5990e4..829f8b073d3 100644 --- a/torchrl/objectives/redq.py +++ b/torchrl/objectives/redq.py @@ -2,13 +2,15 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import math from dataclasses import dataclass from numbers import Number from typing import Union import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule, TensorDictSequential from tensordict.utils import NestedKey @@ -29,6 +31,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class REDQLosses: + """The tensorclass for The REDQLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class REDQLoss(LossModule): """REDQ Loss module. 
@@ -252,6 +268,7 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + return_tensorclass: bool = False, ): super().__init__() self._in_keys = None @@ -320,6 +337,7 @@ def __init__( self._vmap_getdist = _vmap_func( self.actor_network, func="get_dist_params", randomness=self.vmap_randomness ) + self.return_tensorclass = return_tensorclass @property def target_entropy(self): @@ -421,7 +439,7 @@ def _qvalue_params_cat(self, selected_q_params): return qvalue_params @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> REDQLosses: obs_keys = self.actor_network.in_keys tensordict_select = tensordict.clone(False).select( "next", *obs_keys, self.tensor_keys.action @@ -565,6 +583,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: }, [], ) + if self.return_tensorclass: + return REDQLosses._from_tensordict(td_out) return td_out diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 9738b922c5d..9270cedcd67 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -10,7 +10,7 @@ from dataclasses import dataclass import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.utils import NestedKey @@ -30,6 +30,20 @@ ) +@tensorclass +class ReinforceLosses: + """The tensorclass for The Reinforce Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class ReinforceLoss(LossModule): """Reinforce loss module. @@ -223,6 +237,7 @@ def __init__( functional: bool = True, actor: ProbabilisticTensorDictSequential = None, critic: ProbabilisticTensorDictSequential = None, + return_tensorclass: bool = False, ) -> None: if actor is not None: actor_network = actor @@ -281,6 +296,7 @@ def __init__( if gamma is not None: raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self.return_tensorclass = return_tensorclass @property def functional(self): @@ -366,7 +382,7 @@ def in_keys(self, values): self._in_keys = values @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> ReinforceLosses: advantage = tensordict.get(self.tensor_keys.advantage, None) if advantage is None: self.value_estimator( @@ -392,7 +408,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: td_out = TensorDict({"loss_actor": loss_actor}, []) td_out.set("loss_value", self.loss_critic(tensordict).mean()) - + if self.return_tensorclass: + return ReinforceLosses._from_tensordict(td_out) return td_out def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: diff --git a/torchrl/objectives/sac.py b/torchrl/objectives/sac.py index 5b722fd05f3..e26c847c358 100644 --- a/torchrl/objectives/sac.py +++ b/torchrl/objectives/sac.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + import math import warnings from dataclasses import dataclass @@ -11,7 +13,7 @@ import numpy as np import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey @@ -43,6 +45,20 @@ def new_func(self, *args, **kwargs): return new_func +@tensorclass +class SACLosses: + """The tensorclass for The SACLoss Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class SACLoss(LossModule): """TorchRL implementation of the SAC loss. @@ -280,6 +296,7 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + return_tensorclass: bool = False, ) -> None: self._in_keys = None self._out_keys = None @@ -381,6 +398,7 @@ def __init__( self._vmap_qnetwork00 = _vmap_func( qvalue_network, randomness=self.vmap_randomness ) + self.return_tensorclass = return_tensorclass @property def target_entropy_buffer(self): @@ -532,7 +550,7 @@ def out_keys(self, values): self._out_keys = values @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> SACLosses: shape = None if tensordict.ndimension() > 1: shape = tensordict.shape @@ -567,7 +585,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: } if self._version == 1: out["loss_value"] = loss_value.mean() - return TensorDict(out, []) + td_out = TensorDict(out, []) + if self.return_tensorclass: + return SACLosses._from_tensordict(td_out) + return td_out @property @_cache_values diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py index 877a8f0c819..9aedfd37d46 100644 --- a/torchrl/objectives/td3.py +++ b/torchrl/objectives/td3.py @@ -2,12 +2,14 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + from dataclasses import dataclass from typing import Optional, Tuple import torch -from tensordict import TensorDict, TensorDictBase +from tensordict import tensorclass, TensorDict, TensorDictBase from tensordict.nn import dispatch, TensorDictModule from tensordict.utils import NestedKey from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec @@ -26,6 +28,20 @@ from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator +@tensorclass +class TD3Losses: + """The tensorclass for The TD3 Loss class.""" + + loss_objective: torch.Tensor + loss_critic: torch.Tensor | None = None + loss_entropy: torch.Tensor | None = None + entropy: torch.Tensor | None = None + + @property + def aggregate_loss(self): + return self.loss_critic + self.loss_objective + self.loss_entropy + + class TD3Loss(LossModule): """TD3 Loss module. 
@@ -217,6 +233,7 @@ def __init__( gamma: float = None, priority_key: str = None, separate_losses: bool = False, + return_tensorclass: bool = False, ) -> None: super().__init__() self._in_keys = None @@ -299,6 +316,7 @@ def __init__( self._vmap_actor_network00 = _vmap_func( self.actor_network, randomness=self.vmap_randomness ) + self.return_tensorclass = return_tensorclass def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: @@ -447,7 +465,7 @@ def value_loss(self, tensordict): return loss_qval, metadata @dispatch - def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + def forward(self, tensordict: TensorDictBase) -> TD3Losses: tensordict_save = tensordict loss_actor, metadata_actor = self.actor_loss(tensordict) loss_qval, metadata_value = self.value_loss(tensordict_save) @@ -467,6 +485,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: }, batch_size=[], ) + if self.return_tensorclass: + return TD3Losses._from_tensordict(td_out) return td_out
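Usage note: the snippet below is a minimal, hand-built sketch of the pattern this patch introduces, not part of the diff itself. It assumes the patch is applied; the tensor values are placeholders, and in normal use the container is produced by a loss module constructed with return_tensorclass=True, e.g. A2CLoss(actor, critic, return_tensorclass=True)(tensordict).

import torch

from torchrl.objectives.a2c import A2CLosses

# Hand-built container with placeholder values; normally produced by
# A2CLoss(..., return_tensorclass=True)(tensordict).
losses = A2CLosses(
    loss_objective=torch.tensor(1.0),
    loss_critic=torch.tensor(0.5),
    loss_entropy=torch.tensor(0.1),
    batch_size=[],
)

# Loss terms are typed attributes rather than TensorDict keys.
assert losses.loss_objective.item() == 1.0

# aggregate_loss sums the objective, critic and entropy terms: tensor(1.6).
total = losses.aggregate_loss

Note that aggregate_loss adds loss_objective, loss_critic and loss_entropy directly, so it presumes all three fields are populated rather than left at their None defaults.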