[Feature] adding tensor classes annotation for loss functions #1905
base: main
Changes from 17 commits
torchrl/objectives/a2c.py
@@ -2,14 +2,16 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import contextlib
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple

 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -33,6 +35,34 @@
 )


+class LossContainerBase:
+    """ContainerBase class loss tensorclass's."""

Reviewer: That isn't very explicit. We should say what this class is about. Also I think it should live in the

Reviewer: I'm also wondering if we should not just make the base a tensorclass and inherit from it, without creating new tensorclasses?

Author: If I try to make the base a tensorclass, I get the error below.

+
+    __getitem__ = TensorDictBase.__getitem__
+
+    def aggregate_loss(self):
+        result = 0.0
+        for key in self.__dataclass_attr__:
+            if key.startswith("loss_"):
+                result += getattr(self, key)
+        return result

Reviewer: Should be a property:
    result = torch.zeros((), device=self.device)
    ...
    return result

Reviewer: Missing docstring for this method.
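For concreteness, here is a sketch of the base's aggregate_loss with both review comments applied: a property, a tensor accumulator, and a docstring. This is a hypothetical revision, not code from the diff; it also keeps the PR's `__dataclass_attr__` loop, though the standard dataclass attribute is `__dataclass_fields__`.

    @property
    def aggregate_loss(self) -> torch.Tensor:
        """Return the sum of all fields whose name starts with ``loss_``."""
        # Accumulate into a scalar tensor on the container's device,
        # as suggested, instead of into a Python float.
        result = torch.zeros((), device=self.device)
        for key in self.__dataclass_attr__:
            if key.startswith("loss_"):
                result = result + getattr(self, key)
        return result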
+
+
+@tensorclass

Reviewer: Doesn't it work if we make the base class a tensorclass?

Author: Yes, it doesn't work.

+class A2CLosses(LossContainerBase):
+    """The tensorclass for The A2CLoss Loss class."""
+
+    loss_actor: torch.Tensor
+    loss_objective: torch.Tensor
+    loss_critic: torch.Tensor | None = None
+    loss_entropy: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+
+    @property
+    def aggregate_loss(self):
+        return self.loss_critic + self.loss_objective + self.loss_entropy

Reviewer: No need to recode this.
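As a self-contained illustration of the pattern under discussion (a plain Python base class providing shared behavior, with tensorclass subclasses), here is a minimal sketch that runs with only torch and tensordict installed. ToyLosses and its fields are made up for the example, and the standard `__dataclass_fields__` attribute stands in for the PR's `__dataclass_attr__`:

    from typing import Optional

    import torch
    from tensordict import tensorclass


    class AggregateBase:
        """Plain base class: not itself a tensorclass, only shared behavior."""

        def aggregate_loss(self) -> torch.Tensor:
            # Sum every dataclass field whose name starts with "loss_".
            result = torch.zeros(())
            for key in self.__dataclass_fields__:
                if key.startswith("loss_"):
                    value = getattr(self, key)
                    if value is not None:
                        result = result + value
            return result


    @tensorclass
    class ToyLosses(AggregateBase):
        loss_objective: torch.Tensor
        loss_critic: Optional[torch.Tensor] = None


    losses = ToyLosses(
        loss_objective=torch.tensor(1.0),
        loss_critic=torch.tensor(0.5),
        batch_size=[],
    )
    print(losses.aggregate_loss())  # tensor(1.5000)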
 class A2CLoss(LossModule):
     """TorchRL implementation of the A2C loss.
@@ -129,6 +159,16 @@ class A2CLoss(LossModule):
         batch_size=torch.Size([]),
         device=None,
         is_shared=False)
+    >>> loss = A2CLoss(actor, value, loss_critic_type="l2", return_tensorclass=True)
+    >>> loss(data)
+    A2CLosses(
+        entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        batch_size=torch.Size([]),
+        device=None,
+        is_shared=False)

     This class is compatible with non-tensordict based modules too and can be
     used without recurring to any tensordict-related primitive. In this case,
@@ -174,7 +214,7 @@ class A2CLoss(LossModule):
     method.

     Examples:
-        >>> loss.select_out_keys('loss_objective', 'loss_critic')
+        >>> _ = loss.select_out_keys('loss_objective', 'loss_critic')
        >>> loss_obj, loss_critic = loss(
        ...     observation = torch.randn(*batch, n_obs),
        ...     action = spec.rand(batch),
@@ -240,6 +280,7 @@ def __init__(
         functional: bool = True,
         actor: ProbabilisticTensorDictSequential = None,
         critic: ProbabilisticTensorDictSequential = None,
+        return_tensorclass: bool = False,

Reviewer: Should be added to the docstrings.

Author: Working on it.

Author: @vmoens I added doctests for the tensorclass changes, but I see some doctest issues and blockers. Can you please help me resolve them?

         reduction: str = None,
     ):
         if actor is not None:
@@ -300,6 +341,7 @@ def __init__(
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self.loss_critic_type = loss_critic_type
+        self.return_tensorclass = return_tensorclass

     @property
     def functional(self):
@@ -455,7 +497,7 @@ def _cached_detach_critic_network_params(self):
         return self.critic_network_params.detach()

     @dispatch()
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> A2CLosses:

Reviewer posted a suggested change on this line (the proposed replacement was not captured here).

         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
@@ -474,6 +516,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coef * entropy)
         if self.critic_coef:
-            loss_critic = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
+            loss_critic = self.loss_critic(tensordict).mean()
+            td_out.set("loss_critic", loss_critic.mean())
+        if self.return_tensorclass:
+            return A2CLosses._from_tensordict(td_out)
         td_out = td_out.named_apply(
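Building on the docstring example above (with actor, value and data as defined there), the tensorclass return value could then be consumed along these lines; a hypothetical usage sketch, relying on the aggregate_loss property that A2CLosses defines:

    >>> loss = A2CLoss(actor, value, loss_critic_type="l2", return_tensorclass=True)
    >>> losses = loss(data)               # an A2CLosses instance, not a TensorDict
    >>> losses.aggregate_loss.backward()  # loss_critic + loss_objective + loss_entropy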
torchrl/objectives/cql.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import math
 import warnings
 from copy import deepcopy

@@ -12,7 +14,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
 from tensordict.utils import NestedKey, unravel_key
 from torch import Tensor
@@ -36,6 +38,32 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


+class LossContainerBase:
+    """ContainerBase class loss tensorclass's."""
+
+    __getitem__ = TensorDictBase.__getitem__
+
+    def aggregate_loss(self):
+        result = 0.0
+        for key in self.__dataclass_attr__:
+            if key.startswith("loss_"):
+                result += getattr(self, key)
+        return result
+
+
+@tensorclass
+class CQLLosses(LossContainerBase):
+    """The tensorclass for The CQLLoss Loss class."""
+
+    alpha: torch.Tensor
+    loss_actor: torch.Tensor | None = None
+    loss_actor_bc: torch.Tensor | None = None
+    loss_qvalue: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+    loss_alpha: torch.Tensor | None = None
+    loss_cql: torch.Tensor | None = None

Reviewer: Ditto (same comments as on the A2C file).


 class CQLLoss(LossModule):
     """TorchRL implementation of the continuous CQL loss.
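Note that LossContainerBase is duplicated verbatim from the first file. Following the earlier review comment that it should live in one place, a hypothetical refactor could host it in a shared module (the exact location is an assumption on my part, though torchrl/objectives/common.py already hosts LossModule):

    # Hypothetical: torchrl/objectives/common.py
    import torch
    from tensordict import TensorDictBase


    class LossContainerBase:
        """Base container for loss tensorclasses; sums all ``loss_*`` fields."""

        __getitem__ = TensorDictBase.__getitem__

        def aggregate_loss(self) -> torch.Tensor:
            result = torch.zeros(())
            for key in self.__dataclass_fields__:  # standard dataclass attribute
                if key.startswith("loss_"):
                    value = getattr(self, key)
                    if value is not None:
                        result = result + value
            return result

    # a2c.py and cql.py would then import it instead of redefining it:
    # from torchrl.objectives.common import LossContainerBase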
@@ -129,12 +157,27 @@ class CQLLoss(LossModule):
         entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
         loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
         loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
         loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
         loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
         loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
         batch_size=torch.Size([]),
         device=None,
         is_shared=False)
+    >>> loss = CQLLoss(actor, qvalue, return_tensorclass=True)
+    >>> loss(data)
+    CQLLosses(
+        alpha=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_actor=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_actor_bc=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_alpha=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_cql=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        loss_qvalue=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+        batch_size=torch.Size([]),
+        device=None,
+        is_shared=False)

     This class is compatible with non-tensordict based modules too and can be
     used without recurring to any tensordict-related primitive. In this case,
@@ -174,20 +217,21 @@ class CQLLoss(LossModule):
     >>> loss = CQLLoss(actor, qvalue)
     >>> batch = [2, ]
     >>> action = spec.rand(batch)
-    >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
+    >>> loss_actor, loss_qvalue, _, _, _, _ = loss(

Reviewer: Is this really the output now?

     ...     observation=torch.randn(*batch, n_obs),
     ...     action=action,
     ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
     ...     next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
     ...     next_observation=torch.zeros(*batch, n_obs),
-    ...     next_reward=torch.randn(*batch, 1))
+    ...     next_reward=torch.randn(*batch, 1),
+    ... )
     >>> loss_actor.backward()

     The output keys can also be filtered using the :meth:`CQLLoss.select_out_keys`
     method.

     Examples:
-        >>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
+        >>> loss.select_out_keys('loss_actor', 'loss_qvalue')
        >>> loss_actor, loss_qvalue = loss(
        ...     observation=torch.randn(*batch, n_obs),
        ...     action=action,
@@ -271,6 +315,7 @@ def __init__(
         num_random: int = 10,
         with_lagrange: bool = False,
         lagrange_thresh: float = 0.0,
+        return_tensorclass: bool = False,
     ) -> None:
         self._out_keys = None
         super().__init__()
@@ -356,6 +401,7 @@ def __init__(
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
+        self.return_tensorclass = return_tensorclass

     @property
     def target_entropy(self):
@@ -524,7 +570,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         }
         if self.with_lagrange:
             out["loss_alpha_prime"] = alpha_prime_loss.mean()
-        return TensorDict(out, [])
+        td_out = TensorDict(out, [])
+        if self.return_tensorclass:
+            return CQLLosses._from_tensordict(td_out)
+        return td_out

     @property
     @_cache_values
@@ -1007,6 +1056,7 @@ def __init__(
         delay_value: bool = True,
         gamma: float = None,
         action_space=None,
+        return_tensorclass: bool = False,
     ) -> None:
         super().__init__()
         self._in_keys = None

@@ -1047,6 +1097,7 @@ def __init__(

         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self.return_tensorclass = return_tensorclass

     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
@@ -1178,7 +1229,7 @@ def value_loss(
         return loss, metadata

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDict:
+    def forward(self, tensordict: TensorDictBase) -> CQLLosses:
         """Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.

         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign

@@ -1203,6 +1254,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             source=source,
             batch_size=[],
         )
+        if self.return_tensorclass:
+            return CQLLosses._from_tensordict(td_out)

         return td_out
Reviewer: Not sure why we need these.
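To make the return_tensorclass branch concrete: CQLLosses._from_tensordict wraps an existing TensorDict in the tensorclass without copying its tensors. A minimal sketch, assuming tensordict is installed and CQLLosses is defined as in this diff; the field values are placeholders:

    import torch
    from tensordict import TensorDict

    # Mimic the `out` dictionary built at the end of CQLLoss.forward (dummy values).
    td_out = TensorDict(
        {
            "alpha": torch.tensor(0.2),
            "entropy": torch.tensor(0.9),
            "loss_qvalue": torch.tensor(1.3),
        },
        batch_size=[],
    )
    losses = CQLLosses._from_tensordict(td_out)  # a tensorclass view over td_out
    print(losses.loss_qvalue)  # tensor(1.3000)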