Hi! I've been looking into the integration of muP into the Megatron-LM setup and I was wondering about the _rescale_parameters() method of MuReadout in the case of shared (tied) input/output embeddings. Specifically, in the Transformer example I am not sure it is in line with the embedding initialisation suggested in the paper (i.e., constant in width).
Currently, in the example:
the encoder is initialised from N(0,1) <- the default nn.Embedding init (mup/examples/Transformer/model.py, line 93 at a33ea80);
the decoder is first initialised within MuSharedReadout from U(-1/sqrt(fan_in), 1/sqrt(fan_in)) <- the default nn.Linear init (mup/mup/layer.py, line 67 at a33ea80), but its weights are then overwritten with those of the encoder on the next line (68), so they also end up with the N(0,1) init;
finally, once set_base_shapes() is called, both the encoder and decoder weights are rescaled within _rescale_parameters() by *= self.width_mult()**0.5, which leaves them with standard deviation sqrt(d/d_0), i.e. an init that scales with width (see the numerical sketch below).
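To make the effect concrete, here is a small numerical sketch of that rescaling (plain PyTorch; the widths and vocabulary size are arbitrary placeholders, not the example's actual values):

```python
import torch

d0, d = 256, 1024                 # base width d_0 and target width d (placeholders)
width_mult = d / d0

w = torch.randn(1000, d)          # N(0, 1), as nn.Embedding initialises by default
print(w.std().item())             # ~1.0

w *= width_mult ** 0.5            # what _rescale_parameters() effectively applies
print(w.std().item())             # ~sqrt(d / d_0) = 2.0 -> the init scales with width
```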
However, the muP paper suggests initialising them as constants in width to be muP compatible. It should also be mentioned that in the untied case the output embedding weights are initialised to zero, so _rescale_parameters() has no effect and things are consistent with the paper.
Below I also attach the coordinate check plots for the Transformer example for untied, tied+rescaling (current implementation) and tied+no rescaling (_rescale_parameters() disabled), respectively. One can see that for untied the norms are nicely flat, for tied+rescaling some layers have growing activations, and for tied+no rescaling one layer has a vanishing trend.
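For reference, a check of this kind can be reproduced by hand roughly as follows; this is a minimal sketch that only tracks the output coordinates (the plots above track every layer, typically via forward hooks), with the model, batch and optimizer factories left as placeholders:

```python
import torch

def coord_check(widths, make_model, make_batch, make_opt, nsteps=3):
    """Record the average coordinate size of the model output after a few
    training steps for each width; under muP this should stay roughly flat."""
    stats = {}
    for d in widths:
        torch.manual_seed(0)
        model = make_model(d)               # e.g. the muP Transformer at width d
        opt = make_opt(model)               # e.g. a muP optimizer such as mup.MuSGD
        out = None
        for _ in range(nsteps):
            opt.zero_grad()
            out = model(make_batch())
            out.pow(2).mean().backward()    # placeholder loss
            opt.step()
        stats[d] = out.abs().mean().item()  # average |coordinate| of the output
    return stats
```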
So I was wondering whether _rescale_parameters() should be disabled in the tied-embedding scenario to keep the init constant in width, given that the tied weights inherit the N(0,1) initialisation from nn.Embedding()?
Thanks for pointing this out! Your analysis seems correct.
A simple fix is to add self._has_rescaled_params = True to the constructor of MuSharedReadout, so we don't trigger rescaling. I'll do that after making sure it doesn't have unintended consequences.
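Roughly, the change would look like this (the existing constructor lines below are paraphrased from mup/mup/layer.py and may not match the file verbatim):

```python
from mup import MuReadout

class MuSharedReadout(MuReadout):
    """MuReadout whose weight is tied to an nn.Embedding layer's weight."""
    def __init__(self, weight, bias=True, **kwargs):
        super().__init__(*weight.shape, bias=bias, **kwargs)
        self.weight = weight
        # Mark the parameters as already rescaled so that set_base_shapes()
        # skips _rescale_parameters() and the tied N(0, 1) embedding init
        # stays constant in width.
        self._has_rescaled_params = True
```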
The vanishing preactivation in the last row should be the final logits. At initialization the readout output has Gaussian-process behavior governed by the CLT, while the readout's 1/width multiplier is set for the LLN scaling that emerges after training, so the logits at init shrink like 1/sqrt(width); the vanishing trend in that row is expected. A way to get rid of it is to initialize the shared embedding layer to zero, which is okay as long as the embedded input is not always zero (e.g., through a non-zero positional embedding). You should get flat curves that way.
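For example, something along these lines (a minimal, hypothetical stub with placeholder names, not the actual Transformer example code):

```python
import torch
import torch.nn as nn

class TiedEmbeddingFrontEnd(nn.Module):
    """Hypothetical stub: zero-initialized shared token embedding plus a
    non-zero positional embedding, so the embedded input is never all zeros."""
    def __init__(self, ntoken, d_model, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(ntoken, d_model)    # weight shared with the readout
        self.pos_emb = nn.Embedding(max_len, d_model)
        nn.init.zeros_(self.tok_emb.weight)             # shared embedding init = 0
        nn.init.normal_(self.pos_emb.weight, std=0.02)  # keeps the input non-zero

    def forward(self, tokens):                          # tokens: (batch, seq)
        positions = torch.arange(tokens.size(1), device=tokens.device)
        return self.tok_emb(tokens) + self.pos_emb(positions)
```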
Thank you @edwardjhu for your answer and suggestions! Initialising shared embeddings to zero is a good idea, I will try that out in the Megatron setup and see if the curves look flat there :)