Commit: amend

vmoens committed Jul 8, 2024
1 parent 845c8a9 commit 68a1a9f
Showing 8 changed files with 77 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/modules.rst
@@ -317,6 +317,7 @@ Regular modules
     Conv3dNet
     SqueezeLayer
     Squeeze2dLayer
+    BatchRenorm1d

 Algorithm-specific modules
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
8 changes: 4 additions & 4 deletions sota-implementations/crossq/utils.py
@@ -4,8 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-
-from batchrenorm import BatchRenorm
 from tensordict.nn import InteractionType, TensorDictModule
 from tensordict.nn.distributions import NormalParamExtractor
 from torch import nn, optim
@@ -26,6 +24,8 @@
 from torchrl.envs.utils import ExplorationType, set_exploration_type
 from torchrl.modules import MLP, ProbabilisticActor, ValueOperator
 from torchrl.modules.distributions import TanhNormal
+
+from torchrl.modules.models.batchrenorm1d import BatchRenorm1d
 from torchrl.objectives import CrossQLoss

 # ====================================================================
@@ -154,7 +154,7 @@ def make_crossQ_agent(cfg, train_env, device):
         "num_cells": cfg.network.actor_hidden_sizes,
         "out_features": 2 * action_spec.shape[-1],
         "activation_class": get_activation(cfg.network.actor_activation),
-        "norm_class": BatchRenorm,
+        "norm_class": BatchRenorm1d,
         "norm_kwargs": {
             "momentum": cfg.network.batch_norm_momentum,
             "num_features": cfg.network.actor_hidden_sizes[-1],
@@ -201,7 +201,7 @@ def make_crossQ_agent(cfg, train_env, device):
         "num_cells": cfg.network.critic_hidden_sizes,
         "out_features": 1,
         "activation_class": get_activation(cfg.network.critic_activation),
-        "norm_class": BatchRenorm,
+        "norm_class": BatchRenorm1d,
         "norm_kwargs": {
             "momentum": cfg.network.batch_norm_momentum,
             "num_features": cfg.network.critic_hidden_sizes[-1],
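For context: torchrl's MLP accepts a norm_class and norm_kwargs and builds a normalization layer from them after each hidden linear layer, which is how BatchRenorm1d ends up inside the actor and critic networks above. A minimal sketch of that usage, with made-up sizes and momentum standing in for the cfg.network values:

    import torch
    from torchrl.modules import MLP
    from torchrl.modules.models.batchrenorm1d import BatchRenorm1d

    # Illustrative hyperparameters only; the script reads the real ones from cfg.network.
    mlp = MLP(
        in_features=8,
        out_features=4,
        num_cells=[256, 256],
        activation_class=torch.nn.ReLU,
        norm_class=BatchRenorm1d,
        norm_kwargs={"momentum": 0.01, "num_features": 256},
    )
    out = mlp(torch.randn(32, 8))  # -> shape (32, 4)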
35 changes: 34 additions & 1 deletion test/test_modules.py
@@ -34,7 +34,14 @@
     VDNMixer,
 )
 from torchrl.modules.distributions.utils import safeatanh, safetanh
-from torchrl.modules.models import Conv3dNet, ConvNet, MLP, NoisyLazyLinear, NoisyLinear
+from torchrl.modules.models import (
+    BatchRenorm1d,
+    Conv3dNet,
+    ConvNet,
+    MLP,
+    NoisyLazyLinear,
+    NoisyLinear,
+)
 from torchrl.modules.models.decision_transformer import (
     _has_transformers,
     DecisionTransformer,
@@ -1438,6 +1445,32 @@ def test_python_gru(device, bias, dropout, batch_first, num_layers):
     torch.testing.assert_close(h1, h2)


+class TestBatchRenorm:
+    @pytest.mark.parametrize("num_steps", [0, 5])
+    def test_batchrenorm(self, num_steps):
+        torch.manual_seed(0)
+        bn = torch.nn.BatchNorm1d(5, momentum=0.1, eps=1e-5)
+        brn = BatchRenorm1d(
+            5, momentum=0.1, eps=1e-5, warmup_steps=num_steps, max_d=10000, max_r=10000
+        )
+        bn.train()
+        brn.train()
+        data_train = torch.randn(100, 5).split(25)
+        data_test = torch.randn(100, 5)
+        for d in data_train:
+            _ = bn(d)
+            _ = brn(d)
+
+        bn.eval()
+        brn.eval()
+        torch.testing.assert_close(bn(data_test), brn(data_test))
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
2 changes: 1 addition & 1 deletion torchrl/envs/batched_envs.py
@@ -26,8 +26,8 @@
     LazyStackedTensorDict,
     TensorDict,
     TensorDictBase,
+    unravel_key,
 )
-from tensordict._tensordict import unravel_key
 from torch import multiprocessing as mp
 from torchrl._utils import (
     _check_for_faulty_process,
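The change above only moves the unravel_key import from the private tensordict._tensordict extension module to tensordict's public namespace; behavior is unchanged. As a rough illustration of what the helper does (a sketch assuming tensordict's public API):

    from tensordict import unravel_key

    # Nested key tuples are flattened into a single tuple.
    assert unravel_key(("a", ("b", ("c",)))) == ("a", "b", "c")
    # Plain string keys pass through unchanged.
    assert unravel_key("a") == "a"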
2 changes: 1 addition & 1 deletion torchrl/envs/transforms/transforms.py
@@ -39,7 +39,7 @@
     unravel_key,
     unravel_key_list,
 )
-from tensordict._tensordict import _unravel_key_to_tuple
+from tensordict._C import _unravel_key_to_tuple
 from tensordict.nn import dispatch, TensorDictModuleBase
 from tensordict.utils import expand_as_right, expand_right, NestedKey
 from torch import nn, Tensor
2 changes: 2 additions & 0 deletions torchrl/modules/models/__init__.py
@@ -6,6 +6,8 @@

 from torchrl.modules.tensordict_module.common import DistributionalDQNnet

+from .batchrenorm1d import BatchRenorm1d
+
 from .decision_transformer import DecisionTransformer
 from .exploration import NoisyLazyLinear, NoisyLinear, reset_noise
 from .model_based import (
57 changes: 34 additions & 23 deletions sota-implementations/crossq/batchrenorm.py → torchrl/modules/models/batchrenorm1d.py
@@ -6,36 +6,44 @@
 import torch.nn as nn


-class BatchRenorm(nn.Module):
+class BatchRenorm1d(nn.Module):
     """
     BatchRenorm Module (https://arxiv.org/abs/1702.03275).

     The code is adapted from https://github.com/google-research/corenet

     BatchRenorm is an enhanced version of the standard BatchNorm. Unlike BatchNorm,
-    BatchRenorm utilizes running statistics to normalize batches after an initial warmup phase.
-    This approach reduces the impact of "outlier" batches that may occur during extended training periods,
-    making BatchRenorm more robust for long training runs.
+    it utilizes running statistics to normalize batches after an initial warmup phase.
+    This approach reduces the impact of "outlier" batches that may occur during
+    extended training periods, making BatchRenorm more robust for long training runs.

     During the warmup phase, BatchRenorm functions identically to a BatchNorm layer.

     Args:
         num_features (int): Number of features in the input tensor.
-        momentum (float, optional): Momentum factor for computing the running mean and variance. Default is 0.01.
-        eps (float, optional): Small value added to the variance to avoid division by zero. Default is 1e-5.
-        max_r (float, optional): Maximum value for the scaling factor r. Default is 3.0.
-        max_d (float, optional): Maximum value for the bias factor d. Default is 5.0.
-        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance. Default is 10000.
+
+    Keyword Args:
+        momentum (float, optional): Momentum factor for computing the running mean and variance.
+            Defaults to ``0.01``.
+        eps (float, optional): Small value added to the variance to avoid division by zero.
+            Defaults to ``1e-5``.
+        max_r (float, optional): Maximum value for the scaling factor r.
+            Defaults to ``3.0``.
+        max_d (float, optional): Maximum value for the bias factor d.
+            Defaults to ``5.0``.
+        warmup_steps (int, optional): Number of warm-up steps for the running mean and variance.
+            Defaults to ``10000``.
     """

     def __init__(
         self,
-        num_features,
-        momentum=0.01,
-        eps=1e-5,
-        max_r=3.0,
-        max_d=5.0,
-        warmup_steps=10000,
+        num_features: int,
+        *,
+        momentum: float = 0.01,
+        eps: float = 1e-5,
+        max_r: float = 3.0,
+        max_d: float = 5.0,
+        warmup_steps: int = 10000,
     ):
         super().__init__()
         self.num_features = num_features
@@ -56,9 +64,12 @@ def __init__(
         self.bias = nn.Parameter(torch.zeros(num_features, dtype=torch.float32))

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        assert x.dim() >= 2
+        if not x.dim() >= 2:
+            raise ValueError(
+                f"The {type(self).__name__} expects a 2D (or more) tensor, got {x.dim()}."
+            )

         view_dims = [1, x.shape[1]] + [1] * (x.dim() - 2)
-        _v = lambda v: v.view(view_dims)
+
+        def _v(v):
+            return v.view(view_dims)
@@ -79,18 +90,18 @@ def _v(v):
             )

             # Compute warmup factor (0 during warmup, 1 after warmup)
-            warmup_factor = torch.clamp(
-                self.num_batches_tracked / self.warmup_steps, 0.0, 1.0
-            )
-            r = 1.0 + (r - 1.0) * warmup_factor
-            d = d * warmup_factor
+            if self.warmup_steps > 0:
+                warmup_factor = self.num_batches_tracked / self.warmup_steps
+                r = 1.0 + (r - 1.0) * warmup_factor
+                d = d * warmup_factor

             x = (x - _v(b_mean)) / _v(b_std) * _v(r) + _v(d)

-            unbiased_var = b_var.detach() * x.shape[1] / (x.shape[1] - 1)
+            unbiased_var = b_var.detach() * x.shape[0] / (x.shape[0] - 1)
             self.running_var += self.momentum * (unbiased_var - self.running_var)
             self.running_mean += self.momentum * (b_mean.detach() - self.running_mean)
             self.num_batches_tracked += 1
+            self.num_batches_tracked.clamp_max_(self.warmup_steps)
         else:
             x = (x - _v(self.running_mean)) / _v(running_std)
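For reference, a stripped-down sketch of the Batch Renormalization step that the forward pass above implements (Ioffe, 2017, https://arxiv.org/abs/1702.03275). This is illustrative only: the warmup blending, affine weight/bias, and buffer bookkeeping are omitted, and the function name is not part of the module:

    import torch

    def batch_renorm_step(x, running_mean, running_var,
                          momentum=0.01, eps=1e-5, max_r=3.0, max_d=5.0):
        # Per-feature statistics over the batch dimension.
        b_mean = x.mean(0)
        b_var = x.var(0, unbiased=False)
        b_std = (b_var + eps).sqrt()
        r_std = (running_var + eps).sqrt()

        # Renorm corrections: r rescales, d shifts; both are clamped and
        # detached so gradients flow only through the batch statistics.
        r = (b_std.detach() / r_std).clamp(1.0 / max_r, max_r)
        d = ((b_mean.detach() - running_mean) / r_std).clamp(-max_d, max_d)
        y = (x - b_mean) / b_std * r + d

        # Running statistics track the Bessel-corrected batch variance,
        # hence the n / (n - 1) factor over the batch size n (the quantity
        # this commit fixes: x.shape[0], not x.shape[1]).
        n = x.shape[0]
        running_mean += momentum * (b_mean.detach() - running_mean)
        running_var += momentum * (b_var.detach() * n / (n - 1) - running_var)
        return y

In eval mode the module normalizes with the running statistics alone, which is what the new TestBatchRenorm test checks against torch.nn.BatchNorm1d.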
3 changes: 0 additions & 3 deletions torchrl/objectives/crossq.py
@@ -251,7 +251,6 @@ def __init__(
         fixed_alpha: bool = False,
         target_entropy: Union[str, float] = "auto",
         delay_actor: bool = False,
-        gamma: float = None,
         priority_key: str = None,
         separate_losses: bool = False,
         reduction: str = None,
@@ -326,8 +325,6 @@

         self._target_entropy = target_entropy
         self._action_spec = action_spec
-        if gamma is not None:
-            raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
         self._vmap_qnetworkN0 = _vmap_func(
             self.qvalue_network, (None, 0), randomness=self.vmap_randomness
         )
