Skip to content

[Feature] adding tensor classes annotation for loss functions #1905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e4761a3
adding tensor classes annotation for loss functions
SandishKumarHN Feb 12, 2024
5d432c8
review changes add doctests for tensorclass
SandishKumarHN Feb 17, 2024
387953f
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 17, 2024
8e16b63
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 17, 2024
bfb5930
Merge remote-tracking branch 'origin/main' into tensorclass-losses
vmoens Feb 21, 2024
bfde82f
amend
vmoens Feb 21, 2024
7473445
amend
vmoens Feb 22, 2024
8163f90
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 22, 2024
b21e43e
build error fix
SandishKumarHN Feb 22, 2024
60e3d51
Reviewers:
SandishKumarHN Feb 22, 2024
44a70e6
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 22, 2024
23ef8ea
build error fix
SandishKumarHN Feb 23, 2024
191ab2e
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 23, 2024
1e373ca
build error fix, doc test aggregated func
SandishKumarHN Feb 23, 2024
3f058e1
build error fix, docstring formatted
SandishKumarHN Feb 23, 2024
582c9c5
build error fix, docstring formatted
SandishKumarHN Feb 24, 2024
79d8a29
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 28, 2024
64837f9
flake8 errors
SandishKumarHN Feb 29, 2024
715d4c0
review changes - 1
SandishKumarHN Feb 29, 2024
5bb8894
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 11, 2024
7c0ae77
compiler errors
SandishKumarHN Mar 11, 2024
e9125fb
Update torchrl/objectives/decision_transformer.py
SandishKumarHN Mar 12, 2024
bae4237
compiler errors
SandishKumarHN Mar 12, 2024
e17c91e
review changes
SandishKumarHN Mar 14, 2024
f07c4f4
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 14, 2024
8b5e0ff
review changes
SandishKumarHN Mar 18, 2024
73a4dcd
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 18, 2024
9b5f4e6
review changes
SandishKumarHN Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/source/reference/objectives.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ DQN

DQNLoss
DistributionalDQNLoss
DQNLosses

DDPG
----
Expand All @@ -101,6 +102,7 @@ DDPG
:template: rl_template_noinherit.rst

DDPGLoss
DDPGLosses

SAC
---
Expand All @@ -111,6 +113,7 @@ SAC

SACLoss
DiscreteSACLoss
SACLosses

REDQ
----
Expand All @@ -120,6 +123,7 @@ REDQ
:template: rl_template_noinherit.rst

REDQLoss
REDQLosses

IQL
----
Expand All @@ -130,6 +134,7 @@ IQL

IQLLoss
DiscreteIQLLoss
IQLLosses

CQL
----
Expand All @@ -140,6 +145,7 @@ CQL

CQLLoss
DiscreteCQLLoss
CQLLosses

DT
----
Expand All @@ -150,6 +156,7 @@ DT

DTLoss
OnlineDTLoss
DTLosses

TD3
----
Expand All @@ -159,6 +166,7 @@ TD3
:template: rl_template_noinherit.rst

TD3Loss
TD3Losses

PPO
---
Expand All @@ -170,6 +178,7 @@ PPO
PPOLoss
ClipPPOLoss
KLPENPPOLoss
PPOLosses

A2C
---
Expand All @@ -179,6 +188,7 @@ A2C
:template: rl_template_noinherit.rst

A2CLoss
A2CLosses

Reinforce
---------
Expand All @@ -188,6 +198,7 @@ Reinforce
:template: rl_template_noinherit.rst

ReinforceLoss
ReinforceLosses

Dreamer
-------
Expand All @@ -199,6 +210,7 @@ Dreamer
DreamerActorLoss
DreamerModelLoss
DreamerValueLoss
DreamerActorLosses

Multi-agent objectives
-----------------------
Expand Down
14 changes: 11 additions & 3 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
loss_function="l2",
delay_value=delay_value,
double_dqn=double_dqn,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -799,7 +800,7 @@ def test_dqn_notensordict(
SoftUpdate(dqn_loss, eps=0.5)
loss_val = dqn_loss(**kwargs)
loss_val_td = dqn_loss(td)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
torch.testing.assert_close(loss_val_td.get("loss_objective"), loss_val)

def test_distributional_dqn_tensordict_keys(self):
torch.manual_seed(self.seed)
Expand Down Expand Up @@ -883,7 +884,7 @@ def test_dqn_reduction(self, reduction):
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])
assert loss[key].shape == torch.Size([2])

@pytest.mark.parametrize("atoms", range(4, 10))
@pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"])
Expand Down Expand Up @@ -911,7 +912,7 @@ def test_distributional_dqn_reduction(self, reduction, atoms):
for key in loss.keys():
if not key.startswith("loss"):
continue
assert loss[key].shape == torch.Size([])
assert loss[key].shape == torch.Size([2, 4])


class TestQMixer(LossModuleTestBase):
Expand Down Expand Up @@ -1565,6 +1566,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
loss_function="l2",
delay_actor=delay_actor,
delay_value=delay_value,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -2232,6 +2234,7 @@ def test_td3(
noise_clip=noise_clip,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -4457,6 +4460,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
num_qvalue_nets=num_qvalue,
loss_function="l2",
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -5302,6 +5306,7 @@ def test_cql(
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
Expand Down Expand Up @@ -7695,6 +7700,7 @@ def test_reinforce_value_net(
critic_network=value_net,
delay_value=delay_value,
functional=functional,
return_tensorclass=False,
)

td = TensorDict(
Expand Down Expand Up @@ -8361,6 +8367,7 @@ def test_dreamer_world_model(
reco_loss=reco_loss,
delayed_clamp=delayed_clamp,
free_nats=free_nats,
return_tensorclass=False,
)
loss_td, _ = loss_module(tensordict)
for loss_str, lmbda in zip(
Expand Down Expand Up @@ -9220,6 +9227,7 @@ def test_iql(
temperature=temperature,
expectile=expectile,
loss_function="l2",
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
self.num_samples = self._param.shape[-1]

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))
return super().log_prob(value.int().argmax(dim=-1))

@property
def mode(self) -> torch.Tensor:
Expand Down
49 changes: 45 additions & 4 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d

from torchrl.objectives.common import LossModule
from torchrl.objectives.common import LossContainerBase, LossModule

from torchrl.objectives.utils import (
_cache_values,
Expand All @@ -36,6 +36,17 @@
)


@tensorclass
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doens't it work if we make the base class a tensorclass?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, It doesn't work.

class A2CLosses(LossContainerBase):
"""The tensorclass for The A2CLoss Loss class."""

loss_actor: torch.Tensor
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None


class A2CLoss(LossModule):
"""TorchRL implementation of the A2C loss.

Expand Down Expand Up @@ -137,6 +148,16 @@ class A2CLoss(LossModule):
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> loss = A2CLoss(actor, value, loss_critic_type="l2", return_tensorclass=True)
>>> loss(data)
A2CLosses(
entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)

This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
Expand Down Expand Up @@ -182,7 +203,7 @@ class A2CLoss(LossModule):
method.

Examples:
>>> loss.select_out_keys('loss_objective', 'loss_critic')
>>> _ = loss.select_out_keys('loss_objective', 'loss_critic')
>>> loss_obj, loss_critic = loss(
... observation = torch.randn(*batch, n_obs),
... action = spec.rand(batch),
Expand Down Expand Up @@ -248,6 +269,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
return_tensorclass: bool = False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be added to the docstrings

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

working on it.

Copy link
Author

@SandishKumarHN SandishKumarHN Feb 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vmoens added doctests for tensorclass changes. but I see some doctest issues and blockers. can you please help me resolve.

  • there are some existing doctest failures, we might need a separate task to address.
  • what would be the aggregate_loss for each loss within tensorclass?
  • there are some existing errors like
  1.    ```Cannot interpret 'torch.int64' as a data type```
    
  2.    ```'key "action_value" not found in TensorDict with keys [\'done\', \'logits\', \'observation\', \'reward\', \'state_value\', \'terminated\']' ```
    
  3.    ```NameError: name 'actor' is not defined```
    
  4. etc

reduction: str = None,
clip_value: float | None = None,
):
Expand Down Expand Up @@ -309,6 +331,21 @@ def __init__(
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
self.return_tensorclass = return_tensorclass

if clip_value is not None:
if isinstance(clip_value, float):
clip_value = torch.tensor(clip_value)
elif isinstance(clip_value, torch.Tensor):
if clip_value.numel() != 1:
raise ValueError(
f"clip_value must be a float or a scalar tensor, got {clip_value}."
)
else:
raise ValueError(
f"clip_value must be a float or a scalar tensor, got {clip_value}."
)
self.register_buffer("clip_value", clip_value)

if clip_value is not None:
if isinstance(clip_value, float):
Expand Down Expand Up @@ -502,7 +539,7 @@ def _cached_detach_critic_network_params(self):
return self.critic_network_params.detach()

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> A2CLosses | TensorDictBase:
tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
Expand All @@ -523,6 +560,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.critic_coef:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if self.return_tensorclass:
return A2CLosses._from_tensordict(td_out)
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
if value_clip_fraction is not None:
td_out.set("value_clip_fraction", value_clip_fraction)
td_out = td_out.named_apply(
Expand Down
15 changes: 13 additions & 2 deletions torchrl/objectives/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,19 @@ def __init__(cls, name, bases, attr_dict):
cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)


class LossContainerBase:
"""A Base Container class for loss class, which is a subclass of nn.Module."""

@property
def aggregate_loss(self):
"""Aggregate the loss across all losses."""
result = torch.zeros((), device=self.device)
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result


class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
"""A parent class for RL losses.

Expand Down Expand Up @@ -252,7 +265,6 @@ def _compare_and_expand(param):
return param._apply_nest(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
Expand All @@ -276,7 +288,6 @@ def _compare_and_expand(param):
params.apply(
_compare_and_expand,
batch_size=[expand_dim, *params.shape],
filter_empty=False,
call_on_nested=True,
),
no_convert=True,
Expand Down
Loading