adding tensor classes annotation for loss functions
SandishKumarHN committed Feb 13, 2024
1 parent 899af07 commit e4761a3
Showing 13 changed files with 309 additions and 36 deletions.
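This commit threads a new return_tensorclass flag through the loss modules: left at its default of False, forward() keeps returning a plain TensorDict as before; set to True, it returns the matching tensorclass (A2CLosses, CQLLosses, DDPGLosses). A minimal sketch of the typed API this enables, assuming a TorchRL build that includes this commit; the tensor values are dummies:

import torch

from torchrl.objectives.a2c import A2CLosses  # tensorclass added by this commit

# Build the tensorclass directly from dummy loss terms to show attribute access.
losses = A2CLosses(
    loss_objective=torch.tensor(1.0),
    loss_critic=torch.tensor(0.5),
    loss_entropy=torch.tensor(0.1),
    entropy=torch.tensor(0.9),
    batch_size=[],
)
print(losses.loss_objective)  # attribute access instead of td["loss_objective"]
print(losses.aggregate_loss)  # loss_critic + loss_objective + loss_entropy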
21 changes: 19 additions & 2 deletions test/test_cost.py
@@ -491,7 +491,11 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
             action_spec_type=action_spec_type, device=device
         )
         loss_fn = DQNLoss(
-            actor, loss_function="l2", delay_value=delay_value, double_dqn=double_dqn
+            actor,
+            loss_function="l2",
+            delay_value=delay_value,
+            double_dqn=double_dqn,
+            return_tensorclass=False,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
@@ -1490,6 +1494,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
             loss_function="l2",
             delay_actor=delay_actor,
             delay_value=delay_value,
+            return_tensorclass=False,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
@@ -2118,6 +2123,7 @@ def test_td3(
             noise_clip=noise_clip,
             delay_actor=delay_actor,
             delay_qvalue=delay_qvalue,
+            return_tensorclass=False,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
@@ -4216,6 +4222,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
             num_qvalue_nets=num_qvalue,
             loss_function="l2",
             delay_qvalue=delay_qvalue,
+            return_tensorclass=False,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
@@ -5013,6 +5020,7 @@ def test_cql(
             with_lagrange=with_lagrange,
             delay_actor=delay_actor,
             delay_qvalue=delay_qvalue,
+            return_tensorclass=False,
         )

         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
@@ -6648,7 +6656,13 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
         else:
             raise NotImplementedError

-        loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional)
+        loss_fn = A2CLoss(
+            actor,
+            value,
+            loss_critic_type="l2",
+            functional=functional,
+            return_tensorclass=False,
+        )

         # Check error is raised when actions require grads
         td["action"].requires_grad = True
@@ -7113,6 +7127,7 @@ def test_reinforce_value_net(
             critic_network=value_net,
             delay_value=delay_value,
             functional=functional,
+            return_tensorclass=False,
         )

         td = TensorDict(
@@ -7705,6 +7720,7 @@ def test_dreamer_world_model(
             reco_loss=reco_loss,
             delayed_clamp=delayed_clamp,
             free_nats=free_nats,
+            return_tensorclass=False,
         )
         loss_td, _ = loss_module(tensordict)
         for loss_str, lmbda in zip(
@@ -8525,6 +8541,7 @@ def test_iql(
             temperature=temperature,
             expectile=expectile,
             loss_function="l2",
+            return_tensorclass=False,
         )
         if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
             with pytest.raises(NotImplementedError):
24 changes: 22 additions & 2 deletions torchrl/objectives/a2c.py
@@ -2,14 +2,16 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import contextlib
 import warnings
 from copy import deepcopy
 from dataclasses import dataclass
 from typing import Tuple

 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
 from tensordict.utils import NestedKey
 from torch import distributions as d
@@ -31,6 +33,20 @@
 )


+@tensorclass
+class A2CLosses:
+    """The tensorclass for The A2CLoss Loss class."""
+
+    loss_objective: torch.Tensor
+    loss_critic: torch.Tensor | None = None
+    loss_entropy: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+
+    @property
+    def aggregate_loss(self):
+        return self.loss_critic + self.loss_objective + self.loss_entropy
+
+
 class A2CLoss(LossModule):
     """TorchRL implementation of the A2C loss.
@@ -234,6 +250,7 @@ def __init__(
         functional: bool = True,
         actor: ProbabilisticTensorDictSequential = None,
         critic: ProbabilisticTensorDictSequential = None,
+        return_tensorclass: bool = False,
     ):
         if actor is not None:
             actor_network = actor
@@ -290,6 +307,7 @@ def __init__(
         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self.loss_critic_type = loss_critic_type
+        self.return_tensorclass = return_tensorclass

     @property
     def functional(self):
@@ -445,7 +463,7 @@ def _cached_detach_critic_network_params(self):
         return self.critic_network_params.detach()

     @dispatch()
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> A2CLosses:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
@@ -466,6 +484,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         if self.critic_coef:
             loss_critic = self.loss_critic(tensordict).mean()
             td_out.set("loss_critic", loss_critic.mean())
+        if self.return_tensorclass:
+            return A2CLosses._from_tensordict(td_out)
         return td_out

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
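With return_tensorclass=True, A2CLoss.forward wraps the assembled TensorDict via A2CLosses._from_tensordict, as shown above; for A2C the field names mirror the string keys the loss already writes (loss_objective, loss_critic, loss_entropy, entropy). A small sketch of that wrapping step in isolation; all fields are supplied because whether partially populated tensordicts round-trip the same way is an assumption, not something this diff shows:

import torch
from tensordict import TensorDict

from torchrl.objectives.a2c import A2CLosses

# Mirror the forward() path: collect loss terms in a TensorDict, then wrap it.
td_out = TensorDict(
    {
        "loss_objective": torch.tensor(2.0),
        "loss_critic": torch.tensor(0.7),
        "loss_entropy": torch.tensor(0.05),
        "entropy": torch.tensor(1.2),
    },
    batch_size=[],
)
losses = A2CLosses._from_tensordict(td_out)
assert isinstance(losses, A2CLosses)
print(losses.aggregate_loss)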
31 changes: 28 additions & 3 deletions torchrl/objectives/cql.py
@@ -2,6 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
 import math
 import warnings
 from copy import deepcopy
@@ -12,7 +14,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule
 from tensordict.utils import NestedKey, unravel_key
 from torch import Tensor
@@ -36,6 +38,20 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


+@tensorclass
+class CQLLosses:
+    """The tensorclass for The CQLLoss Loss class."""
+
+    loss_objective: torch.Tensor
+    loss_critic: torch.Tensor | None = None
+    loss_entropy: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+
+    @property
+    def aggregate_loss(self):
+        return self.loss_critic + self.loss_objective + self.loss_entropy
+
+
 class CQLLoss(LossModule):
     """TorchRL implementation of the continuous CQL loss.
@@ -269,6 +285,7 @@ def __init__(
         num_random: int = 10,
         with_lagrange: bool = False,
         lagrange_thresh: float = 0.0,
+        return_tensorclass: bool = False,
     ) -> None:
         self._out_keys = None
         super().__init__()
@@ -354,6 +371,7 @@ def __init__(
         self._vmap_qvalue_network00 = _vmap_func(
             self.qvalue_network, randomness=self.vmap_randomness
         )
+        self.return_tensorclass = return_tensorclass

     @property
     def target_entropy(self):
@@ -521,7 +539,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         }
         if self.with_lagrange:
             out["loss_alpha_prime"] = alpha_prime_loss.mean()
-        return TensorDict(out, [])
+        td_out = TensorDict(out, [])
+        if self.return_tensorclass:
+            return CQLLosses._from_tensordict(td_out)
+        return td_out

     @property
     @_cache_values
@@ -1000,6 +1021,7 @@ def __init__(
         delay_value: bool = True,
         gamma: float = None,
         action_space=None,
+        return_tensorclass: bool = False,
     ) -> None:
         super().__init__()
         self._in_keys = None
@@ -1040,6 +1062,7 @@ def __init__(

         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
+        self.return_tensorclass = return_tensorclass

     def _forward_value_estimator_keys(self, **kwargs) -> None:
         if self._value_estimator is not None:
@@ -1171,7 +1194,7 @@ def value_loss(
         return loss, metadata

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDict:
+    def forward(self, tensordict: TensorDictBase) -> CQLLosses:
         """Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.

         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
@@ -1196,6 +1219,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             source=source,
             batch_size=[],
         )
+        if self.return_tensorclass:
+            return CQLLosses._from_tensordict(td_out)

         return td_out

25 changes: 22 additions & 3 deletions torchrl/objectives/ddpg.py
@@ -10,7 +10,7 @@
 from typing import Tuple

 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import tensorclass, TensorDict, TensorDictBase
 from tensordict.nn import dispatch, TensorDictModule

 from tensordict.utils import NestedKey, unravel_key
@@ -26,6 +26,20 @@
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


+@tensorclass
+class DDPGLosses:
+    """The tensorclass for The DDPGLoss class."""
+
+    loss_objective: torch.Tensor
+    loss_critic: torch.Tensor | None = None
+    loss_entropy: torch.Tensor | None = None
+    entropy: torch.Tensor | None = None
+
+    @property
+    def aggregate_loss(self):
+        return self.loss_critic + self.loss_objective + self.loss_entropy
+
+
 class DDPGLoss(LossModule):
     """The DDPG Loss class.
@@ -189,6 +203,7 @@ def __init__(
         delay_value: bool = True,
         gamma: float = None,
         separate_losses: bool = False,
+        return_tensorclass: bool = False,
     ) -> None:
         self._in_keys = None
         super().__init__()
@@ -229,6 +244,7 @@ def __init__(
         )

         self.loss_function = loss_function
+        self.return_tensorclass = return_tensorclass

         if gamma is not None:
             raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
@@ -266,7 +282,7 @@ def in_keys(self, values):
         self._in_keys = values

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDict:
+    def forward(self, tensordict: TensorDictBase) -> DDPGLosses:
         """Computes the DDPG losses given a tensordict sampled from the replay buffer.

         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
Expand All @@ -283,10 +299,13 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
loss_value, metadata = self.loss_value(tensordict)
loss_actor, metadata_actor = self.loss_actor(tensordict)
metadata.update(metadata_actor)
return TensorDict(
td_out = TensorDict(
source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
batch_size=[],
)
if self.return_tensorclass:
return DDPGLosses._from_tensordict(td_out)
return td_out

def loss_actor(
self,
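Since the flag defaults to False, existing training loops keep receiving a TensorDict. Downstream code that wants to handle both return types can branch on the type; the sketch below assumes tensorclass instances are not TensorDictBase subclasses, and the helper name and the loss_-prefix convention are illustrative, not part of this commit:

import torch
from tensordict import TensorDict, TensorDictBase

from torchrl.objectives.ddpg import DDPGLosses  # tensorclass added by this commit


def total_loss(loss_out):
    """Sum loss terms from either a plain TensorDict or a loss tensorclass."""
    if isinstance(loss_out, TensorDictBase):
        # Legacy path: string keys, sum every "loss_*" entry.
        return sum(v for k, v in loss_out.items() if k.startswith("loss_"))
    # Tensorclass path: use the aggregate defined on the class.
    return loss_out.aggregate_loss


# Works on the legacy dict-style output ...
td = TensorDict({"loss_actor": torch.tensor(1.0), "loss_value": torch.tensor(2.0)}, [])
print(total_loss(td))  # tensor(3.)
# ... and on the new tensorclass output (dummy values).
tc = DDPGLosses(
    loss_objective=torch.tensor(1.0),
    loss_critic=torch.tensor(2.0),
    loss_entropy=torch.tensor(0.0),
    batch_size=[],
)
print(total_loss(tc))  # tensor(3.)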
(Diffs for the remaining 9 changed files are not shown.)
