-
Notifications
You must be signed in to change notification settings - Fork 130
Parameter initialization
Albert Zeyer edited this page Jan 19, 2022
·
16 revisions
- std: standard deviation
- var: variance
- E[X]: expected value (mean, average)
- std = sqrt(var), var = std ** 2, var = E[(X - E[X]) ** 2] = E[X**2] - (E[X])**2
- TF
VarianceScaling
- std = sqrt(scale / fan)
- normal is always truncated normal
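- Truncated normal: the stddev passed to the sampler is std / .87962566103423978, so that the samples still have the desired std after truncation (the constant is scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.), as used in TF VarianceScaling)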
- Uniform bound = sqrt(3) * std (uniformly draw samples from interval [-bound, bound], results in the given std)
- fan_in, fan_out: input/output dimension, potentially multiplied by filter sizes (receptive field size) in case of convolution
- fan_avg = (fan_in + fan_out) / 2
- Xavier Glorot (paper 2010): VarianceScaling(scale=1.0, mode="fan_avg", usually distribution="uniform")
- Kaiming He (paper 2015): VarianceScaling(scale=2., mode="fan_in", usually distribution="normal")
-
RETURNN Theano:
VarianceScaling(scale=6.0, mode="fan_avg", distribution="normal")
-
RETURNN TensorFlow: Glorot uniform =
VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
-
Keras: Glorot uniform
-
Lingvo: Glorot uniform
-
PyTorch:
uniform(-1/sqrt(fan_in), 1/sqrt(fan_in))
= VarianceScaling(scale=1. / 3, mode="fan_in", distribution="uniform")
-
PyTorch proposed:
kaiming_normal(mode='fan_in')
= VarianceScaling(scale=2., mode="fan_in", distribution="normal")
-
Transformer:
- RETURNN commonly used for both attention projection and FF:
VarianceScaling(scale=0.78, mode='fan_in', distribution='uniform')
- RETURNN Transformer LM commonly used (here):
VarianceScaling(mode='fan_in', distribution='uniform', scale=1.0)
- T2T defaults: initializer="uniform_unit_scaling", initializer_gain=1.0, means
VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")
-
Fairseq attention projection:
VarianceScaling(mode='fan_avg', distribution='uniform', scale=0.5)
- ESPNet Conformer FF:
VarianceScaling(mode='fan_avg', distribution='uniform', scale=1.0)
-
gpt-neo (config):
random_normal(stddev=0.02)
and random_normal(stddev=0.02 / sqrt(n_layer))
- RETURNN commonly used for both attention projection and FF:
def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Fill ``tensor`` in place with values drawn uniformly from [-bound, bound].

    The bound is ``sqrt(3) * gain / sqrt(fan)``, i.e. the half-width of the
    uniform distribution whose standard deviation is ``gain / sqrt(fan)``.

    Args:
        tensor: tensor to initialize; it is modified in place.
        a: parameter forwarded to ``calculate_gain`` (the negative slope when
            ``nonlinearity`` is ``'leaky_relu'``).
        mode: which fan to use, forwarded to ``_calculate_correct_fan``.
        nonlinearity: name of the non-linearity, forwarded to
            ``calculate_gain``.

    Returns:
        The result of ``tensor.uniform_(-bound, bound)``.
    """
    fan_size = _calculate_correct_fan(tensor, mode)
    target_std = calculate_gain(nonlinearity, a) / math.sqrt(fan_size)
    # A uniform distribution on [-b, b] has std b / sqrt(3), hence b = sqrt(3) * std.
    limit = math.sqrt(3.0) * target_std
    # The in-place fill must not be recorded by autograd.
    with torch.no_grad():
        return tensor.uniform_(-limit, limit)
def _calculate_correct_fan(tensor, mode):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
num_input_fmaps = tensor.size(1)
num_output_fmaps = tensor.size(0)
receptive_field_size = 1
if tensor.dim() > 2:
# math.prod is not always available, accumulate the product manually
# we could use functools.reduce but that is not supported by TorchScript
for s in tensor.shape[2:]:
receptive_field_size *= s
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def calculate_gain(nonlinearity, param=None):
    """Return the gain factor associated with a non-linearity.

    Args:
        nonlinearity: name of the non-linearity (``'linear'``, the
            ``conv*`` / ``conv_transpose*`` names, ``'sigmoid'``, ``'tanh'``,
            ``'relu'``, ``'leaky_relu'`` or ``'selu'``).
        param: only used for ``'leaky_relu'``: the negative slope. ``None``
            selects the default slope of 0.01.

    Returns:
        The gain as a number.

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is
            neither ``None`` nor a non-bool int/float for ``'leaky_relu'``.
    """
    unit_gain_fns = (
        'linear', 'conv1d', 'conv2d', 'conv3d',
        'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
        'sigmoid',
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'selu':
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    # Only 'leaky_relu' reaches this point.
    if param is None:
        slope = 0.01
    elif isinstance(param, float) or (
            isinstance(param, int) and not isinstance(param, bool)):
        # True/False are instances of int, hence the explicit bool exclusion.
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))
class VarianceScaling(Initializer):
    """Initializer whose sample variance is scaled by the shape of the weights.

    With ``n`` being ``fan_in``, ``fan_out`` or their average (selected by
    ``mode``), the target variance is ``scale / n``:

    * ``"normal"`` / ``"truncated_normal"``: truncated normal whose stddev
      argument is ``sqrt(scale / n) / .87962566103423978``.
    * ``"untruncated_normal"``: normal with stddev ``sqrt(scale / n)``.
    * ``"uniform"``: uniform on ``[-limit, limit]`` with
      ``limit = sqrt(3 * scale / n)``.
    """

    def __init__(self,
                 scale=1.0,
                 mode="fan_in",
                 distribution="truncated_normal",
                 seed=None,
                 dtype=dtypes.float32):
        """Validate and store the initializer configuration.

        Args:
            scale: positive scaling factor for the variance.
            mode: one of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
            distribution: one of ``"normal"``, ``"uniform"``,
                ``"truncated_normal"``, ``"untruncated_normal"``
                (case-insensitive; it is lower-cased before validation).
            seed: seed forwarded to the random op when sampling.
            dtype: default dtype used when ``__call__`` receives none; it is
                passed through ``_assert_float_dtype``.

        Raises:
            ValueError: if ``scale`` is not positive, or ``mode`` /
                ``distribution`` is not one of the supported values.
        """
        if scale <= 0.:
            raise ValueError("Argument `scale` must be a positive float. Received: "
                             f"{scale}")
        if mode not in {"fan_in", "fan_out", "fan_avg"}:
            raise ValueError("Argument `mode` should be one of ('fan_in', 'fan_out', "
                             f"'fan_avg'). Received: {mode}")
        distribution = distribution.lower()
        if distribution not in {
            "normal", "uniform", "truncated_normal", "untruncated_normal"
        }:
            # The opening quote before `uniform` was missing in this message.
            raise ValueError("Argument `distribution` should be one of ('normal', "
                             "'uniform', 'truncated_normal', 'untruncated_normal'). "
                             f"Received: {distribution}")
        self.scale = scale
        self.mode = mode
        self.distribution = distribution
        self.seed = seed
        self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

    def __call__(self, shape, dtype=None, partition_info=None):
        """Sample a tensor of the given shape.

        Args:
            shape: shape of the tensor to create.
            dtype: dtype of the result; defaults to the dtype given at
                construction.
            partition_info: optional; when given, the fans are computed from
                ``partition_info.full_shape`` instead of ``shape``, while the
                returned tensor still has shape ``shape``.

        Returns:
            The result of the random op for the configured distribution.
        """
        if dtype is None:
            dtype = self.dtype
        scale = self.scale
        scale_shape = shape
        if partition_info is not None:
            scale_shape = partition_info.full_shape
        fan_in, fan_out = _compute_fans(scale_shape)
        # max(1., ...) guards against division by zero for zero-sized fans.
        if self.mode == "fan_in":
            scale /= max(1., fan_in)
        elif self.mode == "fan_out":
            scale /= max(1., fan_out)
        else:
            scale /= max(1., (fan_in + fan_out) / 2.)
        if self.distribution == "normal" or self.distribution == "truncated_normal":
            # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
            stddev = math.sqrt(scale) / .87962566103423978
            return random_ops.truncated_normal(
                shape, 0.0, stddev, dtype, seed=self.seed)
        elif self.distribution == "untruncated_normal":
            stddev = math.sqrt(scale)
            return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
        else:
            # Uniform on [-limit, limit] has variance limit**2 / 3 == scale.
            limit = math.sqrt(3.0 * scale)
            return random_ops.random_uniform(
                shape, -limit, limit, dtype, seed=self.seed)
def _compute_fans(shape):
"""Computes the number of input and output units for a weight shape.
Args:
shape: Integer shape tuple or TF tensor shape.
Returns:
A tuple of integer scalars (fan_in, fan_out).
"""
if len(shape) < 1: # Just to avoid errors for constants.
fan_in = fan_out = 1
elif len(shape) == 1:
fan_in = fan_out = shape[0]
elif len(shape) == 2:
fan_in = shape[0]
fan_out = shape[1]
else:
# Assuming convolution kernels (2D, 3D, or more).
# kernel shape: (..., input_depth, depth)
receptive_field_size = 1
for dim in shape[:-2]:
receptive_field_size *= dim
fan_in = shape[-2] * receptive_field_size
fan_out = shape[-1] * receptive_field_size
return int(fan_in), int(fan_out)