-
Notifications
You must be signed in to change notification settings - Fork 325
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] adding tensor classes annotation for loss functions #1905
base: main
Are you sure you want to change the base?
Changes from 1 commit
e4761a3
5d432c8
387953f
8e16b63
bfb5930
bfde82f
7473445
8163f90
b21e43e
60e3d51
44a70e6
23ef8ea
191ab2e
1e373ca
3f058e1
582c9c5
79d8a29
64837f9
715d4c0
5bb8894
7c0ae77
e9125fb
bae4237
e17c91e
f07c4f4
8b5e0ff
73a4dcd
9b5f4e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -2,14 +2,16 @@ | |||||
# | ||||||
# This source code is licensed under the MIT license found in the | ||||||
# LICENSE file in the root directory of this source tree. | ||||||
from __future__ import annotations | ||||||
|
||||||
import contextlib | ||||||
import warnings | ||||||
from copy import deepcopy | ||||||
from dataclasses import dataclass | ||||||
from typing import Tuple | ||||||
|
||||||
import torch | ||||||
from tensordict import TensorDict, TensorDictBase | ||||||
from tensordict import tensorclass, TensorDict, TensorDictBase | ||||||
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule | ||||||
from tensordict.utils import NestedKey | ||||||
from torch import distributions as d | ||||||
|
@@ -31,6 +33,20 @@ | |||||
) | ||||||
|
||||||
|
||||||
@tensorclass | ||||||
class A2CLosses: | ||||||
"""The tensorclass for The A2CLoss Loss class.""" | ||||||
|
||||||
loss_objective: torch.Tensor | ||||||
loss_critic: torch.Tensor | None = None | ||||||
loss_entropy: torch.Tensor | None = None | ||||||
entropy: torch.Tensor | None = None | ||||||
|
||||||
@property | ||||||
def aggregate_loss(self): | ||||||
return self.loss_critic + self.loss_objective + self.loss_entropy | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to recode this |
||||||
|
||||||
|
||||||
class A2CLoss(LossModule): | ||||||
"""TorchRL implementation of the A2C loss. | ||||||
|
||||||
|
@@ -234,6 +250,7 @@ def __init__( | |||||
functional: bool = True, | ||||||
actor: ProbabilisticTensorDictSequential = None, | ||||||
critic: ProbabilisticTensorDictSequential = None, | ||||||
return_tensorclass: bool = False, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should be added to the docstrings There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Working on it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vmoens I added doctests for the tensorclass changes, but I see some doctest issues and blockers. Can you please help me resolve them?
|
||||||
): | ||||||
if actor is not None: | ||||||
actor_network = actor | ||||||
|
@@ -290,6 +307,7 @@ def __init__( | |||||
if gamma is not None: | ||||||
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) | ||||||
self.loss_critic_type = loss_critic_type | ||||||
self.return_tensorclass = return_tensorclass | ||||||
|
||||||
@property | ||||||
def functional(self): | ||||||
|
@@ -445,7 +463,7 @@ def _cached_detach_critic_network_params(self): | |||||
return self.critic_network_params.detach() | ||||||
|
||||||
@dispatch() | ||||||
def forward(self, tensordict: TensorDictBase) -> TensorDictBase: | ||||||
def forward(self, tensordict: TensorDictBase) -> A2CLosses: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
tensordict = tensordict.clone(False) | ||||||
advantage = tensordict.get(self.tensor_keys.advantage, None) | ||||||
if advantage is None: | ||||||
|
@@ -466,6 +484,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: | |||||
if self.critic_coef: | ||||||
loss_critic = self.loss_critic(tensordict).mean() | ||||||
td_out.set("loss_critic", loss_critic.mean()) | ||||||
if self.return_tensorclass: | ||||||
return A2CLosses._from_tensordict(td_out) | ||||||
return td_out | ||||||
|
||||||
def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
from typing import Tuple | ||
|
||
import torch | ||
from tensordict import TensorDict, TensorDictBase | ||
from tensordict import tensorclass, TensorDict, TensorDictBase | ||
from tensordict.nn import dispatch, TensorDictModule | ||
|
||
from tensordict.utils import NestedKey, unravel_key | ||
|
@@ -26,6 +26,20 @@ | |
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator | ||
|
||
|
||
@tensorclass | ||
class DDPGLosses: | ||
"""The tensorclass for The DDPGLoss class.""" | ||
|
||
loss_objective: torch.Tensor | ||
loss_critic: torch.Tensor | None = None | ||
loss_entropy: torch.Tensor | None = None | ||
entropy: torch.Tensor | None = None | ||
|
||
@property | ||
def aggregate_loss(self): | ||
return self.loss_critic + self.loss_objective + self.loss_entropy | ||
|
||
|
||
class DDPGLoss(LossModule): | ||
"""The DDPG Loss class. | ||
|
||
|
@@ -189,6 +203,7 @@ def __init__( | |
delay_value: bool = True, | ||
gamma: float = None, | ||
separate_losses: bool = False, | ||
return_tensorclass: bool = False, | ||
) -> None: | ||
self._in_keys = None | ||
super().__init__() | ||
|
@@ -229,6 +244,7 @@ def __init__( | |
) | ||
|
||
self.loss_function = loss_function | ||
self.return_tensorclass = return_tensorclass | ||
|
||
if gamma is not None: | ||
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) | ||
|
@@ -266,7 +282,7 @@ def in_keys(self, values): | |
self._in_keys = values | ||
|
||
@dispatch | ||
def forward(self, tensordict: TensorDictBase) -> TensorDict: | ||
def forward(self, tensordict: TensorDictBase) -> DDPGLosses: | ||
"""Computes the DDPG losses given a tensordict sampled from the replay buffer. | ||
|
||
This function will also write a "td_error" key that can be used by prioritized replay buffers to assign | ||
|
@@ -283,10 +299,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: | |
loss_value, metadata = self.loss_value(tensordict) | ||
loss_actor, metadata_actor = self.loss_actor(tensordict) | ||
metadata.update(metadata_actor) | ||
return TensorDict( | ||
td_out = TensorDict( | ||
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata}, | ||
batch_size=[], | ||
) | ||
if self.return_tensorclass: | ||
return DDPGLosses._from_tensordict(td_out) | ||
return td_out | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are two `return` statements, one right after the other. |
||
|
||
def loss_actor( | ||
self, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't it work if we make the base class a tensorclass?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it doesn't work.