Commit: review changes - 1

SandishKumarHN committed Feb 29, 2024
1 parent 64837f9 commit 66fc382
Showing 13 changed files with 79 additions and 223 deletions.
27 changes: 4 additions & 23 deletions torchrl/objectives/a2c.py
@@ -16,7 +16,7 @@
 from tensordict.utils import NestedKey
 from torch import distributions as d

-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase

 from torchrl.objectives.utils import (
     _cache_values,
@@ -34,20 +34,6 @@
     VTrace,
 )

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class A2CLosses(LossContainerBase):
     """The tensorclass for The A2CLoss Loss class."""
@@ -58,11 +44,6 @@ class A2CLosses(LossContainerBase):
     loss_entropy: torch.Tensor | None = None
     entropy: torch.Tensor | None = None

-    @property
-    def aggregate_loss(self):
-        return self.loss_critic + self.loss_objective + self.loss_entropy
-
-
 class A2CLoss(LossModule):
     """TorchRL implementation of the A2C loss.
@@ -164,8 +145,8 @@ class A2CLoss(LossModule):
         A2CLosses(
             entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
             loss_critic=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_entropy=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
-            loss_objective=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_entropy=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+            loss_objective=Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False),
             batch_size=torch.Size([]),
             device=None,
             is_shared=False)
@@ -497,7 +478,7 @@ def _cached_detach_critic_network_params(self):
         return self.critic_network_params.detach()

     @dispatch()
-    def forward(self, tensordict: TensorDictBase) -> A2CLosses:
+    def forward(self, tensordict: TensorDictBase) -> A2CLosses | TensorDictBase:
         tensordict = tensordict.clone(False)
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
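A minimal sketch of consuming the A2CLosses container, assuming this commit is installed. The container is built by hand here with made-up values; in practice A2CLoss.forward() returns it on the tensorclass path:

import torch
from torchrl.objectives.a2c import A2CLosses

# Fields mirror the tensorclass defined above; values are illustrative only.
losses = A2CLosses(
    loss_objective=torch.tensor(0.7),
    loss_critic=torch.tensor(0.3),
    loss_entropy=torch.tensor(0.01),
    entropy=torch.tensor(1.2),
    batch_size=[],
)
# aggregate_loss (inherited from LossContainerBase) sums exactly the
# fields whose names start with "loss_"; plain `entropy` is skipped.
assert torch.isclose(losses.aggregate_loss, torch.tensor(1.01))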
21 changes: 16 additions & 5 deletions torchrl/objectives/common.py
@@ -12,7 +12,7 @@
 from typing import Iterator, List, Optional, Tuple

 import torch
-from tensordict import is_tensor_collection, TensorDict, TensorDictBase
+from tensordict import tensorclass, is_tensor_collection, TensorDict, TensorDictBase

 from tensordict.nn import TensorDictModule, TensorDictModuleBase, TensorDictParams
 from torch import nn
@@ -38,6 +38,19 @@ def __init__(cls, name, bases, attr_dict):
         cls.forward = set_exploration_type(ExplorationType.MODE)(cls.forward)


+class LossContainerBase:
+    """Base container class for loss tensorclasses."""
+
+    __getitem__ = TensorDictBase.__getitem__
+
+    @property
+    def aggregate_loss(self):
+        result = torch.zeros((), device=self.device)
+        for key in self.__dataclass_attr__:
+            if key.startswith("loss_"):
+                result += getattr(self, key)
+        return result
+
 class LossModule(TensorDictModuleBase, metaclass=_LossMeta):
     """A parent class for RL losses.
Expand Down Expand Up @@ -252,7 +265,6 @@ def _compare_and_expand(param):
return param._apply_nest(
_compare_and_expand,
batch_size=[expand_dim, *param.shape],
filter_empty=False,
call_on_nested=True,
)
if not isinstance(param, nn.Parameter):
@@ -276,7 +288,6 @@ def _compare_and_expand(param):
             params.apply(
                 _compare_and_expand,
                 batch_size=[expand_dim, *params.shape],
-                filter_empty=False,
                 call_on_nested=True,
             ),
             no_convert=True,
@@ -298,7 +309,7 @@ def _compare_and_expand(param):
         # set the functional module: we need to convert the params to non-differentiable params
         # otherwise they will appear twice in parameters
         with params.apply(
-            self._make_meta_params, device=torch.device("meta"), filter_empty=False
+            self._make_meta_params, device=torch.device("meta")
         ).to_module(module):
             # avoid buffers and params being exposed
             self.__dict__[module_name] = deepcopy(module)
@@ -309,7 +320,7 @@ def _compare_and_expand(param):
         # we create a TensorDictParams to keep the target params as Buffer instances
         target_params = TensorDictParams(
             params.apply(
-                _make_target_param(clone=create_target_params), filter_empty=False
+                _make_target_param(clone=create_target_params)
             ),
             no_convert=True,
         )
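A minimal sketch of the convention the consolidated base class establishes, assuming this commit is installed; ToyLosses and its fields are invented for illustration:

import torch
from tensordict import tensorclass
from torchrl.objectives.common import LossContainerBase


@tensorclass
class ToyLosses(LossContainerBase):
    loss_actor: torch.Tensor
    loss_value: torch.Tensor
    entropy: torch.Tensor  # no "loss_" prefix, so excluded from the sum


losses = ToyLosses(
    loss_actor=torch.tensor(1.0),
    loss_value=torch.tensor(2.0),
    entropy=torch.tensor(0.5),
    batch_size=[],
)
assert float(losses.aggregate_loss) == 3.0
# __getitem__ is borrowed from TensorDictBase, so dict-style access works too:
assert losses["loss_actor"] == losses.loss_actor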
20 changes: 3 additions & 17 deletions torchrl/objectives/cql.py
@@ -25,7 +25,7 @@

 from torchrl.modules import ProbabilisticActor, QValueActor
 from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -37,20 +37,6 @@

 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class CQLLosses(LossContainerBase):
     """The tensorclass for The CQLLoss Loss class."""
@@ -217,7 +203,7 @@ class CQLLoss(LossModule):
         >>> loss = CQLLoss(actor, qvalue)
         >>> batch = [2, ]
         >>> action = spec.rand(batch)
-        >>> loss_actor, loss_qvalue, _, _, _, _ = loss(
+        >>> loss_actor, loss_actor_bc, loss_qvalue, loss_cql, loss_alpha, loss_alpha_prime = loss(
         ...     observation=torch.randn(*batch, n_obs),
         ...     action=action,
         ...     next_done=torch.zeros(*batch, 1, dtype=torch.bool),
@@ -532,7 +518,7 @@ def out_keys(self, values):
         self._out_keys = values

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> CQLLosses | TensorDictBase:
         shape = None
         if tensordict.ndimension() > 1:
             shape = tensordict.shape
20 changes: 3 additions & 17 deletions torchrl/objectives/ddpg.py
@@ -15,7 +15,7 @@

 from tensordict.utils import NestedKey, unravel_key
 from torchrl.modules.tensordict_module.actors import ActorCriticWrapper
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _cache_values,
     _GAMMA_LMBDA_DEPREC_ERROR,
@@ -26,20 +26,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DDPGLosses(LossContainerBase):
     """The tensorclass for The DDPGLoss class."""
@@ -171,7 +157,7 @@ class DDPGLoss(LossModule):
             method.
     Examples:
-        >>> out_keys = loss.select_out_keys('loss_actor', 'loss_value')
+        >>> _ = loss.select_out_keys('loss_actor', 'loss_value')
         >>> loss_actor, loss_value = loss(
         ...     observation=torch.randn(n_obs),
         ...     action=spec.rand(),
@@ -315,7 +301,7 @@ def in_keys(self, values):
         self._in_keys = values

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> DDPGLosses:
+    def forward(self, tensordict: TensorDictBase) -> DDPGLosses | TensorDictBase:
         """Computes the DDPG losses given a tensordict sampled from the replay buffer.
         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign
18 changes: 2 additions & 16 deletions torchrl/objectives/decision_transformer.py
@@ -16,23 +16,9 @@
 from torch import distributions as d
 from torchrl.modules import ProbabilisticActor

-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import distance_loss

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class OnlineDTLosses(LossContainerBase):
     """The tensorclass for The OnlineDTLoss Loss class."""
@@ -226,7 +212,7 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
         return -log_p.mean(axis=0)

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+    def forward(self, tensordict: TensorDictBase) -> OnlineDTLosses | TensorDictBase:
         """Compute the loss for the Online Decision Transformer."""
         # extract action targets
         tensordict = tensordict.clone(False)
36 changes: 12 additions & 24 deletions torchrl/objectives/dqn.py
@@ -24,7 +24,7 @@
 )
 from torchrl.modules.tensordict_module.common import ensure_tensordict_compatible

-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
@@ -34,31 +34,13 @@
 from torchrl.objectives.value import TDLambdaEstimator
 from torchrl.objectives.value.advantages import TD0Estimator, TD1Estimator

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DQNLosses(LossContainerBase):
     """The tensorclass for The DQN Loss class."""

     loss_objective: torch.Tensor
     loss: torch.Tensor

-    @property
-    def aggregate_loss(self):
-        return self.loss_critic + self.loss_objective + self.loss_entropy
-

 class DQNLoss(LossModule):
     """The DQN Loss class.
@@ -334,7 +316,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
         self._value_estimator.set_keys(**tensor_keys)

     @dispatch
-    def forward(self, tensordict: TensorDictBase) -> DQNLosses:
+    def forward(self, tensordict: TensorDictBase) -> DQNLosses | TensorDictBase:
         """Computes the DQN loss given a tensordict sampled from the replay buffer.
         This function will also write a "td_error" key that can be used by prioritized replay buffers to assign

@@ -404,7 +386,10 @@ def forward(self, tensordict: TensorDictBase) -> DQNLosses:
             inplace=True,
         )
         loss = distance_loss(pred_val_index, target_value, self.loss_function)
-        return TensorDict({"loss": loss.mean()}, [])
+        loss_td = TensorDict({"loss": loss.mean()}, [])
+        if self.return_tensorclass:
+            return DQNLosses._from_tensordict(loss_td)
+        return loss_td


 class DistributionalDQNLoss(LossModule):
@@ -531,7 +516,7 @@ def _log_ps_a_categorical(action, action_log_softmax):
         action = action.expand(new_shape)
         return torch.gather(action_log_softmax, -1, index=action).squeeze(-1)

-    def forward(self, input_tensordict: TensorDictBase) -> DQNLosses:
+    def forward(self, input_tensordict: TensorDictBase) -> DQNLosses | TensorDictBase:
         # from https://github.com/Kaixhin/Rainbow/blob/9ff5567ad1234ae0ed30d8471e8f13ae07119395/agent.py
         tensordict = TensorDict(
             source=input_tensordict, batch_size=input_tensordict.batch_size
@@ -644,8 +629,11 @@ def forward(self, input_tensordict: TensorDictBase) -> DQNLosses:
             loss.detach().unsqueeze(1).to(input_tensordict.device),
             inplace=True,
         )
-        loss_td = TensorDict({"loss": loss.mean()}, [])
-        return loss_td
+        loss = _reduce(loss, reduction=self.reduction)
+        td_out = TensorDict({"loss": loss}, [])
+        if self.return_tensorclass:
+            return DQNLosses._from_tensordict(td_out)
+        return td_out

     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         if value_type is None:
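A usage sketch for the new branch, assuming this commit wires return_tensorclass up as a DQNLoss constructor flag (the diff only shows the branch inside forward(), so the flag's name at the constructor is an assumption):

import torch
from torch import nn
from tensordict import TensorDict
from torchrl.modules import QValueActor
from torchrl.objectives import DQNLoss

n_obs, n_act, batch = 4, 3, 8
actor = QValueActor(nn.Linear(n_obs, n_act), in_keys=["observation"],
                    action_space="one-hot")
loss_fn = DQNLoss(actor, action_space="one-hot",
                  return_tensorclass=True)  # assumed flag
data = TensorDict({
    "observation": torch.randn(batch, n_obs),
    "action": nn.functional.one_hot(torch.randint(n_act, (batch,)), n_act),
    ("next", "observation"): torch.randn(batch, n_obs),
    ("next", "reward"): torch.randn(batch, 1),
    ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
    ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
}, batch_size=[batch])
out = loss_fn(data)   # DQNLosses here; a plain TensorDict otherwise
out["loss"].backward()  # __getitem__ works on both return types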
20 changes: 3 additions & 17 deletions torchrl/objectives/dreamer.py
@@ -14,7 +14,7 @@

 from torchrl.envs.model_based.dreamer import DreamerEnv
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
-from torchrl.objectives.common import LossModule
+from torchrl.objectives.common import LossModule, LossContainerBase
 from torchrl.objectives.utils import (
     _GAMMA_LMBDA_DEPREC_ERROR,
     default_value_kwargs,
@@ -24,20 +24,6 @@
 )
 from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator

-
-class LossContainerBase:
-    """ContainerBase class loss tensorclass's."""
-
-    __getitem__ = TensorDictBase.__getitem__
-
-    def aggregate_loss(self):
-        result = 0.0
-        for key in self.__dataclass_attr__:
-            if key.startswith("loss_"):
-                result += getattr(self, key)
-        return result
-
-
 @tensorclass
 class DreamerModelLosses(LossContainerBase):
     """The tensorclass for The Dreamer Model Loss class."""
@@ -288,7 +274,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:

     def forward(
         self, tensordict: TensorDict
-    ) -> Tuple[DreamerModelLosses, DreamerModelLosses]:
+    ) -> Tuple[DreamerModelLosses, TensorDict] | Tuple[TensorDict, TensorDict]:
         with torch.no_grad():
             tensordict = tensordict.select("state", self.tensor_keys.belief)
             tensordict = tensordict.reshape(-1)
@@ -328,7 +314,7 @@ def forward(
         if self.return_tensorclass:
             return DreamerModelLosses._from_tensordict(
                 loss_tensordict
-            ), DreamerModelLosses._from_tensordict(fake_data.detach())
+            ), fake_data.detach()
         return loss_tensordict, fake_data.detach()

     def lambda_target(self, reward: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
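Since every forward() touched by this commit can now return either a plain TensorDict or a loss tensorclass, downstream code can stay agnostic about the return type. A minimal helper sketch; the helper and its fallback path are ours, not part of the commit, and it assumes tensorclass-decorated containers keep LossContainerBase in their MRO:

import torch
from tensordict import TensorDict
from torchrl.objectives.common import LossContainerBase


def total_loss(out) -> torch.Tensor:
    """Sum the "loss_*" entries of either return type."""
    if isinstance(out, LossContainerBase):
        return out.aggregate_loss  # tensorclass path
    return sum(v for k, v in out.items() if k.startswith("loss_"))


td = TensorDict({"loss_model": torch.tensor(1.5),
                 "entropy": torch.tensor(0.2)}, [])
assert float(total_loss(td)) == 1.5

For DreamerModelLoss the call site unpacks a pair first, e.g. model_losses, sampled = world_model_loss(batch), then total_loss(model_losses).backward().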
(The remaining 6 changed files in this commit are not shown.)