Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] adding tensor classes annotation for loss functions #1905

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e4761a3
adding tensor classes annotation for loss functions
SandishKumarHN Feb 12, 2024
5d432c8
review changes add doctests for tensorclass
SandishKumarHN Feb 17, 2024
387953f
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 17, 2024
8e16b63
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 17, 2024
bfb5930
Merge remote-tracking branch 'origin/main' into tensorclass-losses
vmoens Feb 21, 2024
bfde82f
amend
vmoens Feb 21, 2024
7473445
amend
vmoens Feb 22, 2024
8163f90
review changes add doctests for tensorclass | merging to main
SandishKumarHN Feb 22, 2024
b21e43e
build error fix
SandishKumarHN Feb 22, 2024
60e3d51
Reviewers:
SandishKumarHN Feb 22, 2024
44a70e6
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 22, 2024
23ef8ea
build error fix
SandishKumarHN Feb 23, 2024
191ab2e
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 23, 2024
1e373ca
build error fix, doc test aggregated func
SandishKumarHN Feb 23, 2024
3f058e1
build error fix, docstring formatted
SandishKumarHN Feb 23, 2024
582c9c5
build error fix, docstring formatted
SandishKumarHN Feb 24, 2024
79d8a29
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Feb 28, 2024
64837f9
flake8 errors
SandishKumarHN Feb 29, 2024
715d4c0
review changes - 1
SandishKumarHN Feb 29, 2024
5bb8894
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 11, 2024
7c0ae77
compiler errors
SandishKumarHN Mar 11, 2024
e9125fb
Update torchrl/objectives/decision_transformer.py
SandishKumarHN Mar 12, 2024
bae4237
compiler errors
SandishKumarHN Mar 12, 2024
e17c91e
review changes
SandishKumarHN Mar 14, 2024
f07c4f4
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 14, 2024
8b5e0ff
review changes
SandishKumarHN Mar 18, 2024
73a4dcd
Merge remote-tracking branch 'upstream/main' into tensorclass-losses
SandishKumarHN Mar 18, 2024
9b5f4e6
review changes
SandishKumarHN Mar 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/dqn/dqn_atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(cfg: "DictConfig"): # noqa: F821
sampled_tensordict = sampled_tensordict.to(device)

loss_td = loss_module(sampled_tensordict)
q_loss = loss_td["loss"]
q_loss = loss_td["loss_objective"]
optimizer.zero_grad()
q_loss.backward()
torch.nn.utils.clip_grad_norm_(
Expand Down
2 changes: 1 addition & 1 deletion examples/multiagent/iql.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def train(cfg: "DictConfig"): # noqa: F821
loss_vals = loss_module(subdata)
training_tds.append(loss_vals.detach())

loss_value = loss_vals["loss"]
loss_value = loss_vals["loss_objective"]

loss_value.backward()

Expand Down
14 changes: 13 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,13 @@ def test_dqn(self, delay_value, double_dqn, device, action_spec_type, td_est):
)
loss_fn = DQNLoss(
actor,

loss_function="l2",

delay_value=delay_value,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why we need these

double_dqn=double_dqn,
return_tensorclass=False,,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -799,7 +803,7 @@ def test_dqn_notensordict(
SoftUpdate(dqn_loss, eps=0.5)
loss_val = dqn_loss(**kwargs)
loss_val_td = dqn_loss(td)
torch.testing.assert_close(loss_val_td.get("loss"), loss_val)
torch.testing.assert_close(loss_val_td.get("loss_objective"), loss_val)

def test_distributional_dqn_tensordict_keys(self):
torch.manual_seed(self.seed)
Expand Down Expand Up @@ -1565,6 +1569,7 @@ def test_ddpg(self, delay_actor, delay_value, device, td_est):
loss_function="l2",
delay_actor=delay_actor,
delay_value=delay_value,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -2232,6 +2237,7 @@ def test_td3(
noise_clip=noise_clip,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -4457,6 +4463,7 @@ def test_redq(self, delay_qvalue, num_qvalue, device, td_est):
num_qvalue_nets=num_qvalue,
loss_function="l2",
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down Expand Up @@ -5302,6 +5309,7 @@ def test_cql(
with_lagrange=with_lagrange,
delay_actor=delay_actor,
delay_qvalue=delay_qvalue,
return_tensorclass=False,
)

if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
Expand Down Expand Up @@ -7027,6 +7035,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est, functional):
value,
loss_critic_type="l2",
functional=functional,
reduction=reduction,
)

# Check error is raised when actions require grads
Expand Down Expand Up @@ -7529,6 +7538,7 @@ def test_reinforce_value_net(
critic_network=value_net,
delay_value=delay_value,
functional=functional,
return_tensorclass=False,
)

td = TensorDict(
Expand Down Expand Up @@ -8143,6 +8153,7 @@ def test_dreamer_world_model(
reco_loss=reco_loss,
delayed_clamp=delayed_clamp,
free_nats=free_nats,
return_tensorclass=False,
)
loss_td, _ = loss_module(tensordict)
for loss_str, lmbda in zip(
Expand Down Expand Up @@ -8963,6 +8974,7 @@ def test_iql(
temperature=temperature,
expectile=expectile,
loss_function="l2",
return_tensorclass=False,
)
if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace):
with pytest.raises(NotImplementedError):
Expand Down
2 changes: 1 addition & 1 deletion torchrl/modules/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
self.num_samples = self._param.shape[-1]

def log_prob(self, value: torch.Tensor) -> torch.Tensor:
return super().log_prob(value.argmax(dim=-1))
return super().log_prob(value.int().argmax(dim=-1))

@property
def mode(self) -> torch.Tensor:
Expand Down
52 changes: 49 additions & 3 deletions torchrl/objectives/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import contextlib
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Tuple

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.utils import NestedKey
from torch import distributions as d
Expand All @@ -33,6 +35,34 @@
)


class LossContainerBase:
"""ContainerBase class loss tensorclass's."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That isn't very explicit. We should say what this class is about.

Also I think it should live in the common.py file.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm also wondering if we should not just make the base a tensorclass and inherit from it without creating new tensorclasses?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I try to make the base a tensorclass getting below error.

**********************************************************************
File "/home/sandish/rl/torchrl/objectives/a2c.py", line 144, in a2c.A2CLoss
Failed example:
    loss(data)
Exception raised:
    Traceback (most recent call last):
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/doctest.py", line 1334, in __run
        exec(compile(example.source, filename, "single",
      File "<doctest a2c.A2CLoss[21]>", line 1, in <module>
        loss(data)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
        return func(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/nn/common.py", line 291, in wrapper
        return func(_self, tensordict, *args, **kwargs)
      File "/home/sandish/rl/torchrl/objectives/a2c.py", line 503, in forward
        return A2CLosses._from_tensordict(td_out)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/tensorclass.py", line 327, in wrapper
        raise ValueError(
    ValueError: Keys from the tensordict ({'loss_entropy', 'loss_objective', 'entropy', 'loss_critic'}) must correspond to the class attributes (set()).


__getitem__ = TensorDictBase.__getitem__

def aggregate_loss(self):
result = 0.0
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be a property
Should always return a tensor
Something like

result = torch.zeros((), device=self.device)
...
return result

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing docstring for this method.



@tensorclass
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doens't it work if we make the base class a tensorclass?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, It doesn't work.

class A2CLosses(LossContainerBase):
"""The tensorclass for The A2CLoss Loss class."""

loss_actor: torch.Tensor
loss_objective: torch.Tensor
loss_critic: torch.Tensor | None = None
loss_entropy: torch.Tensor | None = None
entropy: torch.Tensor | None = None

@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to recode this



class A2CLoss(LossModule):
"""TorchRL implementation of the A2C loss.

Expand Down Expand Up @@ -129,6 +159,16 @@ class A2CLoss(LossModule):
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> loss = A2CLoss(actor, value, loss_critic_type="l2", return_tensorclass=True)
>>> loss(data)
A2CLosses(
entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)

This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
Expand Down Expand Up @@ -174,7 +214,7 @@ class A2CLoss(LossModule):
method.

Examples:
>>> loss.select_out_keys('loss_objective', 'loss_critic')
>>> _ = loss.select_out_keys('loss_objective', 'loss_critic')
>>> loss_obj, loss_critic = loss(
... observation = torch.randn(*batch, n_obs),
... action = spec.rand(batch),
Expand Down Expand Up @@ -240,6 +280,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
return_tensorclass: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be added to the docstrings

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

working on it.

Copy link
Author

@SandishKumarHN SandishKumarHN Feb 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vmoens added doctests for tensorclass changes. but I see some doctest issues and blockers. can you please help me resolve.

  • there are some existing doctest failures, we might need a separate task to address.
  • what would be the aggregate_loss for each loss within tensorclass?
  • there are some existing errors like
  1.    ```Cannot interpret 'torch.int64' as a data type```
    
  2.    ```'key "action_value" not found in TensorDict with keys [\'done\', \'logits\', \'observation\', \'reward\', \'state_value\', \'terminated\']' ```
    
  3.    ```NameError: name 'actor' is not defined```
    
  4. etc

reduction: str = None,
):
if actor is not None:
Expand Down Expand Up @@ -300,6 +341,7 @@ def __init__(
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
self.return_tensorclass = return_tensorclass

@property
def functional(self):
Expand Down Expand Up @@ -455,7 +497,7 @@ def _cached_detach_critic_network_params(self):
return self.critic_network_params.detach()

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
def forward(self, tensordict: TensorDictBase) -> A2CLosses | TensorDictBase:

tensordict = tensordict.clone(False)
advantage = tensordict.get(self.tensor_keys.advantage, None)
if advantage is None:
Expand All @@ -474,6 +516,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef:
loss_critic = self.loss_critic(tensordict).mean()
td_out.set("loss_critic", loss_critic.mean())
if self.return_tensorclass:
return A2CLosses._from_tensordict(td_out)
loss_critic = self.loss_critic(tensordict)
td_out.set("loss_critic", loss_critic)
td_out = td_out.named_apply(
Expand Down
65 changes: 59 additions & 6 deletions torchrl/objectives/cql.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import math
import warnings
from copy import deepcopy
Expand All @@ -12,7 +14,7 @@
import numpy as np
import torch
import torch.nn as nn
from tensordict import TensorDict, TensorDictBase
from tensordict import tensorclass, TensorDict, TensorDictBase
from tensordict.nn import dispatch, TensorDictModule
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor
Expand All @@ -36,6 +38,32 @@
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


class LossContainerBase:
"""ContainerBase class loss tensorclass's."""

__getitem__ = TensorDictBase.__getitem__

def aggregate_loss(self):
result = 0.0
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result


@tensorclass
class CQLLosses(LossContainerBase):
"""The tensorclass for The CQLLoss Loss class."""

alpha: torch.Tensor
loss_actor: torch.Tensor | None = None
loss_actor_bc: torch.Tensor | None = None
loss_qvalue: torch.Tensor | None = None
entropy: torch.Tensor | None = None
loss_alpha: torch.Tensor | None = None
loss_cql: torch.Tensor | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto



class CQLLoss(LossModule):
"""TorchRL implementation of the continuous CQL loss.

Expand Down Expand Up @@ -129,12 +157,27 @@ class CQLLoss(LossModule):
entropy: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor_bc: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_cql: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> loss = CQLLoss(actor, qvalue, return_tensorclass=True)
>>> loss(data)
CQLLosses(
alpha=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_actor_bc=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_alpha=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_cql=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
loss_qvalue=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
batch_size=torch.Size([]),
device=None,
is_shared=False)

This class is compatible with non-tensordict based modules too and can be
used without recurring to any tensordict-related primitive. In this case,
Expand Down Expand Up @@ -174,20 +217,21 @@ class CQLLoss(LossModule):
>>> loss = CQLLoss(actor, qvalue)
>>> batch = [2, ]
>>> action = spec.rand(batch)
>>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, *_ = loss(
>>> loss_actor, loss_qvalue, _, _, _, _ = loss(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really the output now?
What I see is that the out_keys are

["loss_actor",
                "loss_actor_bc",
                "loss_qvalue",
                "loss_cql",
                "loss_alpha", ...]

... observation=torch.randn(*batch, n_obs),
... action=action,
... next_done=torch.zeros(*batch, 1, dtype=torch.bool),
... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool),
... next_observation=torch.zeros(*batch, n_obs),
... next_reward=torch.randn(*batch, 1))
... next_reward=torch.randn(*batch, 1),
... )
>>> loss_actor.backward()

The output keys can also be filtered using the :meth:`CQLLoss.select_out_keys`
method.

Examples:
>>> _ = loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss.select_out_keys('loss_actor', 'loss_qvalue')
>>> loss_actor, loss_qvalue = loss(
... observation=torch.randn(*batch, n_obs),
... action=action,
Expand Down Expand Up @@ -271,6 +315,7 @@ def __init__(
num_random: int = 10,
with_lagrange: bool = False,
lagrange_thresh: float = 0.0,
return_tensorclass: bool = False,
) -> None:
self._out_keys = None
super().__init__()
Expand Down Expand Up @@ -356,6 +401,7 @@ def __init__(
self._vmap_qvalue_network00 = _vmap_func(
self.qvalue_network, randomness=self.vmap_randomness
)
self.return_tensorclass = return_tensorclass

@property
def target_entropy(self):
Expand Down Expand Up @@ -524,7 +570,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
}
if self.with_lagrange:
out["loss_alpha_prime"] = alpha_prime_loss.mean()
return TensorDict(out, [])
td_out = TensorDict(out, [])
if self.return_tensorclass:
return CQLLosses._from_tensordict(td_out)
return td_out

@property
@_cache_values
Expand Down Expand Up @@ -1007,6 +1056,7 @@ def __init__(
delay_value: bool = True,
gamma: float = None,
action_space=None,
return_tensorclass: bool = False,
) -> None:
super().__init__()
self._in_keys = None
Expand Down Expand Up @@ -1047,6 +1097,7 @@ def __init__(

if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.return_tensorclass = return_tensorclass

def _forward_value_estimator_keys(self, **kwargs) -> None:
if self._value_estimator is not None:
Expand Down Expand Up @@ -1178,7 +1229,7 @@ def value_loss(
return loss, metadata

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDict:
def forward(self, tensordict: TensorDictBase) -> CQLLosses:
"""Computes the (DQN) CQL loss given a tensordict sampled from the replay buffer.

This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
Expand All @@ -1203,6 +1254,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
source=source,
batch_size=[],
)
if self.return_tensorclass:
return CQLLosses._from_tensordict(td_out)

return td_out

Expand Down
Loading
Loading