implement weight norm
Fix #91
albertz committed Nov 12, 2022
1 parent d883a4c commit 09a0c70
Showing 4 changed files with 197 additions and 2 deletions.
6 changes: 4 additions & 2 deletions nn/__init__.py
@@ -15,6 +15,9 @@
from .math_ import *
from .array_ import *
from .rand import *

from . import init

from .utils import *
from .search import *
from .normalization import *
@@ -25,7 +28,6 @@
from .container import *
from .masked_computation import *
from .attention import *

from .transformer import *
from .conformer import *

from . import init
1 change: 1 addition & 0 deletions nn/utils/__init__.py
@@ -9,3 +9,4 @@
from .label_smoothing import *
from .stochastic_depth import *
from .targets import *
from .weight_norm import weight_norm, remove_weight_norm
176 changes: 176 additions & 0 deletions nn/utils/weight_norm.py
@@ -0,0 +1,176 @@
r"""
Weight Normalization from https://arxiv.org/abs/1602.07868
Code adapted from the PyTorch implementation.
"""

from __future__ import annotations
from typing import Optional, Union, Sequence, TypeVar
import numpy
from ... import nn


T_module = TypeVar('T_module', bound=nn.Module)


def weight_norm(module: T_module, name: str = "weight", dim: Optional[nn.Dim] = nn.NotSpecified) -> T_module:
    r"""Applies weight normalization to a parameter in the given module.

    .. math::
        \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. The parameter specified by
    :attr:`name` (e.g. ``'weight'``) is replaced by a computed tensor, and the two
    underlying parameters, a magnitude ``g`` and a direction ``v``, are stored in a
    :class:`WeightNorm` object attached as ``'{name}_normalized'``.
    In this implementation, the weight is recomputed from the magnitude and direction
    once, when :func:`weight_norm` is called.

    By default, with ``dim=weight.feature_dim``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use ``dim=None``.

    See https://arxiv.org/abs/1602.07868.

    Args:
        module (Module): containing module
        name (str, optional): name of the weight parameter
        dim (nn.Dim, optional): the dimension which is kept; the norm is computed
            over all remaining dimensions

    Returns:
        The original module with weight normalization applied
    """
    weight = getattr(module, name)
    if getattr(module, f"{name}_normalized", None) is not None:
        raise RuntimeError(f"Cannot apply weight_norm twice to the same parameter {name!r}")
    assert isinstance(weight, nn.Parameter)
    if dim is nn.NotSpecified:
        dim = weight.feature_dim  # default per the docstring: norm per output channel

    fn = WeightNorm(weight, dim)

    delattr(module, name)  # remove w from the parameter list
    setattr(module, f"{name}_normalized", fn)  # add the weight norm object
    setattr(module, name, fn.compute_weight())  # set it to the computed weight

    return module
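

# A minimal usage sketch (illustration only, not part of the public API; it mirrors
# test_weight_norm in tests/test_nn_utils.py below). ``nn.Linear`` and ``nn.FeatureDim``
# are used just as in that test.
def _example_weight_norm_usage():
    in_dim = nn.FeatureDim("in", 3)
    net = nn.Linear(in_dim, nn.FeatureDim("out", 5))
    # One scale g per output channel; the norm is taken over the remaining (input) dims.
    weight_norm(net, "weight", net.out_dim)
    # net.weight is now a computed nn.Tensor (g * v / ||v||), no longer an nn.Parameter.
    return net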


def remove_weight_norm(module: T_module, name: str = "weight") -> T_module:
    r"""Removes the weight normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of the weight parameter
    """
    fn = getattr(module, f"{name}_normalized")
    assert isinstance(fn, WeightNorm)
    delattr(module, name)
    delattr(module, f"{name}_normalized")

    p = nn.Parameter(fn.v.shape_ordered, fn.v.dtype)
    p.initial = fn.weight_init()
    setattr(module, name, p)
    return module


class WeightNorm(nn.Module):
    """
    Encapsulates a weight-normalized parameter.
    """

    def __init__(self, weight: nn.Parameter, dim: Optional[nn.Dim], eps: float = 1e-6) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

        # Add g and v as new parameters and express w as g/||v|| * v.
        g = nn.Parameter([dim] if dim is not None else [], weight.dtype)
        v = nn.Parameter(weight.shape_ordered, weight.dtype)
        self.g = g
        self.v = v

        self.norm_axes = v.batch_dims_ordered(dim)
        if isinstance(weight, nn.Parameter) and weight.initial is not None:
            # Custom ParamInit such that any deepcopy will make individual random inits.
            v.initial = WeightNormDirectionParamInit(weight.initial)
            g.initial = WeightNormScaleParamInit(self)
        else:
            g.initial = 1.0

    def compute_weight(self) -> nn.Tensor:
        """computes the actual weight from g and v"""
        g = self.g
        v = self.v
        # See _weight_norm in PyTorch:
        # https://github.com/pytorch/pytorch/blob/324ac93a43a93f671bb34b835926b22d13442735/aten/src/ATen/native/WeightNorm.cpp#L107
        #   v*(g/at::norm_except_dim(v, 2, dim));
        # Tensor norm_except_dim(const Tensor & v, int64_t pow, int64_t dim) {
        #   if (dim == -1)
        #     return v.norm(pow);
        #   else if (dim == 0) {
        #     std::vector<int64_t> output_size(v.dim(), 1);
        #     output_size[0] = v.size(0);
        #     return v.contiguous().view({v.size(0), -1}).norm(pow, 1).view(output_size);
        #   } ...
        assert isinstance(v, nn.Tensor)
        return v * (g * nn.rsqrt(nn.reduce(nn.square(v), mode="sum", axis=self.norm_axes) + self.eps))
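
    # Reference-only sketch (assumption, not used by compute_weight): the same computation
    # in plain numpy, with integer axis indices standing in for the nn.Dim-based norm_axes.
    # ``g`` must be broadcastable against ``v`` (e.g. shape (out_dim, 1, ...) or a scalar).
    @staticmethod
    def _compute_weight_numpy_reference(v: numpy.ndarray, g: numpy.ndarray,
                                        norm_axes: Sequence[int], eps: float = 1e-6) -> numpy.ndarray:
        """w = v * g / sqrt(sum(v**2 over norm_axes) + eps)"""
        norm = numpy.sqrt(numpy.sum(numpy.square(v), axis=tuple(norm_axes), keepdims=True) + eps)
        return v * (g / norm)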

    def g_init(self, weight_init: Union[nn.Tensor, nn.RawTensorTypes]) -> Union[nn.Tensor, nn.RawTensorTypes]:
        """
        Given a specific weight init, calculate the corresponding g init, i.e. ``||weight_init||``,
        such that g * v / ||v|| initially reproduces the original weight.
        """
        if not isinstance(weight_init, nn.Tensor):
            return numpy.sqrt(numpy.square(weight_init) + self.eps)  # assume scalar
        return nn.sqrt(nn.reduce(nn.square(weight_init), mode="sum", axis=self.norm_axes) + self.eps)

    def weight_init(self) -> Optional[nn.init.ParamInitType]:
        """
        Returns the original weight init if it was wrapped, otherwise None.
        """
        if self.v.initial is None:
            return None
        init = self.v.initial
        if isinstance(init, WeightNormDirectionParamInit):
            return init.weight_init
        return None


class WeightNormDirectionParamInit(nn.init.ParamInit):
    """
    Param init for the direction parameter ``v``, wrapping the original weight init.
    """

    def __init__(self, weight_init: nn.init.ParamInitType):
        self.weight_init = weight_init
        self.weight_init_value = None  # type: Optional[Union[nn.Tensor, nn.RawTensorTypes]]

    def __call__(self, shape: Sequence[nn.Dim], dtype: str) -> Union[nn.Tensor, nn.RawTensorTypes]:
        if isinstance(self.weight_init, nn.init.ParamInit):
            if self.weight_init_value is None:
                self.weight_init_value = self.weight_init(shape, dtype)
                return self.weight_init_value
            raise Exception(f"{self}: Don't call this twice. You probably missed a deepcopy.")
        return self.weight_init

    def __copy__(self):
        return WeightNormDirectionParamInit(self.weight_init)

    def get_weight_init_value(self) -> Union[nn.Tensor, nn.RawTensorTypes]:
        """get the cached value"""
        if isinstance(self.weight_init, nn.init.ParamInit):
            assert self.weight_init_value is not None, f"{self}: Expected to be called before."
            return self.weight_init_value
        return self.weight_init


class WeightNormScaleParamInit(nn.init.ParamInit):
    """
    Param init for the scale parameter ``g``, derived from the direction init.
    """

    def __init__(self, parent: WeightNorm):
        self.parent = parent

    def __call__(self, shape: Sequence[nn.Dim], dtype: str) -> Union[nn.Tensor, nn.RawTensorTypes]:
        v_init = self.parent.v.initial
        if isinstance(v_init, WeightNormDirectionParamInit):
            return self.parent.g_init(v_init.get_weight_init_value())
        return 1.0
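

# Sketch of the intended copy semantics (illustration only, not part of the commit):
# copying the direction init wrapper drops any cached value, so each (deep)copied module
# draws its own random init, while repeated queries on one instance reuse the cached value.
def _example_param_init_copy(weight_init: nn.init.ParamInit):
    import copy

    init_a = WeightNormDirectionParamInit(weight_init)
    init_b = copy.copy(init_a)
    assert init_b.weight_init is init_a.weight_init  # same underlying init
    assert init_b.weight_init_value is None  # but nothing cached yet in the copy
    return init_a, init_b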
16 changes: 16 additions & 0 deletions tests/test_nn_utils.py
@@ -24,3 +24,19 @@ def test_prev_target_seq():
    out.mark_as_default_output()
    config_str = nn.get_returnn_config().get_complete_py_code_str(nn.Module())
    dummy_run_net_single_custom(config_str, eval_flag=True)


def test_weight_norm():
    nn.reset_default_root_name_ctx()
    time_dim = nn.SpatialDim("time")
    in_dim = nn.FeatureDim("in", 3)
    x = nn.Data("data", dim_tags=[nn.batch_dim, time_dim, in_dim])
    x = nn.get_extern_data(x)
    net = nn.Linear(in_dim, nn.FeatureDim("out", 5))
    assert isinstance(net.weight, nn.Parameter)
    nn.weight_norm(net, "weight", net.out_dim)
    assert not isinstance(net.weight, nn.Parameter) and isinstance(net.weight, nn.Tensor)
    y = net(x)
    y.mark_as_default_output()
    config_str = nn.get_returnn_config().get_complete_py_code_str(net)
    dummy_run_net_single_custom(config_str, eval_flag=True)
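

def test_weight_norm_remove():
    # Hedged sketch (not in the original commit): apply weight norm, then remove it again
    # via nn.remove_weight_norm, and check that ``weight`` becomes an nn.Parameter once more.
    nn.reset_default_root_name_ctx()
    in_dim = nn.FeatureDim("in", 3)
    net = nn.Linear(in_dim, nn.FeatureDim("out", 5))
    nn.weight_norm(net, "weight", net.out_dim)
    assert not isinstance(net.weight, nn.Parameter) and isinstance(net.weight, nn.Tensor)
    nn.remove_weight_norm(net, "weight")
    assert isinstance(net.weight, nn.Parameter)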
