diff --git a/docs/source/reference/modules.rst b/docs/source/reference/modules.rst
index ccd6cb23ed0..b46d789ed15 100644
--- a/docs/source/reference/modules.rst
+++ b/docs/source/reference/modules.rst
@@ -317,6 +317,7 @@ Regular modules
     Conv3dNet
     SqueezeLayer
     Squeeze2dLayer
+    BatchRenorm1d

 Algorithm-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/sota-implementations/crossq/utils.py b/sota-implementations/crossq/utils.py
index f6615689384..26798b1ee10 100644
--- a/sota-implementations/crossq/utils.py
+++ b/sota-implementations/crossq/utils.py
@@ -4,8 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-
-from batchrenorm import BatchRenorm
 from tensordict.nn import InteractionType, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn, optim
@@ -26,6 +24,8 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
 from torchrl.modules.distributions import TanhNormal
+
+from torchrl.modules.models.batchrenorm1d import BatchRenorm1d
 from torchrl.objectives import CrossQLoss

 # ====================================================================
@@ -154,7 +154,7 @@ def make_crossQ_agent(cfg, train_env, device):
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
         "activation_class": get_activation(cfg.network.actor_activation),
-        "norm_class": BatchRenorm,
+        "norm_class": BatchRenorm1d,
         "norm_kwargs": {
             "momentum": cfg.network.batch_norm_momentum,
             "num_features": cfg.network.actor_hidden_sizes[-1],
@@ -201,7 +201,7 @@ def make_crossQ_agent(cfg, train_env, device):
         "num_cells": cfg.network.critic_hidden_sizes,
         "out_features": 1,
         "activation_class": get_activation(cfg.network.critic_activation),
-        "norm_class": BatchRenorm,
+        "norm_class": BatchRenorm1d,
         "norm_kwargs": {
             "momentum": cfg.network.batch_norm_momentum,
             "num_features": cfg.network.critic_hidden_sizes[-1],
diff --git a/test/test_modules.py b/test/test_modules.py
index 59adbea653d..e3b774a0358 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -34,7 +34,14 @@
     VDNMixer,
 )
 from torchrl.modules.distributions.utils import safeatanh, safetanh
-from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models import (
+    BatchRenorm1d,
+    Conv3dNet,
+    ConvNet,
+    MLP,
+    NoisyLazyLinear,
+    NoisyLinear,
+)
 from torchrl.modules.models.decision_transformer import (
     _has_transformers,
     DecisionTransformer,
@@ -1438,6 +1445,27 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers):
     torch.testing.assert_close(h1, h2)


+class TestBatchRenorm:
+    @pytest.mark.parametrize("num_steps", [0, 5])
+    def test_batchrenorm(self, num_steps):
+        torch.manual_seed(0)
+        bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5)
+        brn = BatchRenorm1d(
+            5, momentum=0.1, eps=1e-5, warmup_steps=num_steps, max_d=10000, max_r=10000
+        )
+        bn.train()
+        brn.train()
+        data_train = torch.randn(100, 5).split(25)
+        data_test = torch.randn(100, 5)
+        for d in data_train:
+            _ = bn(d)
+            _ = brn(d)
+
+        bn.eval()
+        brn.eval()
+        torch.testing.assert_close(bn(data_test), brn(data_test))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py
index 7f462782757..4241f6613a0 100644
--- a/torchrl/envs/batched_envs.py
+++ b/torchrl/envs/batched_envs.py
@@ -26,8 +26,8 @@
     LazyStackedTensorDict,
     TensorDict,
     TensorDictBase,
+    unravel_key,
 )
-from tensordict._tensordict import unravel_key
 from torch import multiprocessing as mp
 from torchrl._utils import (
     _check_for_faulty_process,
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index bec76c603e6..eb9cdce923d 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -39,7 +39,7 @@
     unravel_key,
     unravel_key_list,
 )
-from tensordict._tensordict import _unravel_key_to_tuple
+from tensordict._C import _unravel_key_to_tuple
 from tensordict.nn import dispatch, TensorDictModuleBase
 from tensordict.utils import expand_as_right, expand_right, NestedKey
 from torch import nn, Tensor
diff --git a/torchrl/modules/models/__init__.py b/torchrl/modules/models/__init__.py
index fb0cc0135b8..2a2fc6b31d3 100644
--- a/torchrl/modules/models/__init__.py
+++ b/torchrl/modules/models/__init__.py
@@ -6,6 +6,8 @@

 from torchrl.modules.tensordict_module.common import DistributionalDQNnet

+from .batchrenorm1d import BatchRenorm1d
+
 from .decision_transformer import DecisionTransformer
 from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
 from .model_based import (
diff --git a/sota-implementations/crossq/batchrenorm.py b/torchrl/modules/models/batchrenorm1d.py
similarity index 69%
rename from sota-implementations/crossq/batchrenorm.py
rename to torchrl/modules/models/batchrenorm1d.py
index 9d1a78d1135..33b4df31cc7 100644
--- a/sota-implementations/crossq/batchrenorm.py
+++ b/torchrl/modules/models/batchrenorm1d.py
@@ -6,14 +6,16 @@
 import torch.nn as nn


-class BatchRenorm(nn.Module):
+class BatchRenorm1d(nn.Module):
     """
     BatchRenorm Module (https://arxiv.org/abs/1702.03275).

+    The code is adapted from https://github.com/google-research/corenet
+
     BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
-    BatchRenorm utilizes running statistics to normalize batches after an initial warmup phase.
-    This approach reduces the impact of "outlier" batches that may occur during extended training periods,
-    making BatchRenorm more robust for long training runs.
+    it utilizes running statistics to normalize batches after an initial warmup phase.
+    This approach reduces the impact of "outlier" batches that may occur during
+    extended training periods, making BatchRenorm more robust for long training runs.

     During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.

@@ -21,21 +23,27 @@ class BatchRenorm(nn.Module):
     Args:
         num_features (int): Number of features in the input tensor.
     Keyword Args:
-        momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01.
-        eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5.
-        max_r (float, optional): Maximum value for the scaling factor r. Default is 3.0.
-        max_d (float, optional): Maximum value for the bias factor d. Default is 5.0.
-        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 10000.
+        momentum (float, optional): Momentum factor for computing the running mean and variance.
+            Defaults to ``0.01``.
+        eps (float, optional): Small value added to the variance to avoid division by zero.
+            Defaults to ``1e-5``.
+        max_r (float, optional): Maximum value for the scaling factor r.
+            Defaults to ``3.0``.
+        max_d (float, optional): Maximum value for the bias factor d.
+            Defaults to ``5.0``.
+        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
+            Defaults to ``10000``.
     """

     def __init__(
         self,
-        num_features,
-        momentum=0.01,
-        eps=1e-5,
-        max_r=3.0,
-        max_d=5.0,
-        warmup_steps=10000,
+        num_features: int,
+        *,
+        momentum: float = 0.01,
+        eps: float = 1e-5,
+        max_r: float = 3.0,
+        max_d: float = 5.0,
+        warmup_steps: int = 10000,
     ):
         super().__init__()
         self.num_features = num_features
@@ -56,9 +64,12 @@ def __init__(
         self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.dim() >= 2
+        if not x.dim() >= 2:
+            raise ValueError(
+                f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}."
+            )
+
         view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)
-        # _v = lambda v: v.view(view_dims)

         def _v(v):
             return v.view(view_dims)
@@ -79,18 +90,18 @@ def _v(v):
             )

             # Compute warmup factor (0 during warmup, 1 after warmup)
-            warmup_factor = torch.clamp(
-                self.num_batches_tracked / self.warmup_steps, 0.0, 1.0
-            )
-            r = 1.0 + (r - 1.0) * warmup_factor
-            d = d * warmup_factor
+            if self.warmup_steps > 0:
+                warmup_factor = self.num_batches_tracked / self.warmup_steps
+                r = 1.0 + (r - 1.0) * warmup_factor
+                d = d * warmup_factor

             x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)

-            unbiased_var = b_var.detach() * x.shape[1] / (x.shape[1] - 1)
+            unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1)
             self.running_var += self.momentum * (unbiased_var - self.running_var)
             self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
             self.num_batches_tracked += 1
+            self.num_batches_tracked.clamp_max_(self.warmup_steps)
         else:
             x = (x - _v(self.running_mean)) / _v(running_std)

diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 6e7e6db1697..2372b9e3163 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -251,7 +251,6 @@ def __init__(
         fixed_alpha: bool = False,
         target_entropy: Union[str, float] = "auto",
         delay_actor: bool = False,
-        gamma: float = None,
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -326,8 +325,6 @@ def __init__(
         self._target_entropy = target_entropy
         self._action_spec = action_spec

-        if gamma is not None:
-            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
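
A minimal usage sketch of the relocated module (not part of the patch; it assumes the `torchrl.modules.models` export added above, and the `warmup_steps=2` value is purely illustrative):

    # Sketch: exercising BatchRenorm1d as exposed by this patch.
    import torch

    from torchrl.modules.models import BatchRenorm1d

    layer = BatchRenorm1d(5, momentum=0.1, eps=1e-5, warmup_steps=2)
    layer.train()
    for batch in torch.randn(100, 5).split(25):
        # warmup_factor ramps from 0 to 1 over the first `warmup_steps` batches,
        # so the earliest batches behave like plain BatchNorm1d (as the new test
        # checks with warmup_steps=0); later batches apply the r/d correction.
        layer(batch)

    layer.eval()  # eval mode normalizes with the running statistics only
    print(layer(torch.randn(10, 5)).shape)  # torch.Size([10, 5])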