
_rescale_parameters() inconsistent with the paper for the tied embedding scenario? #55

ofivite opened this issue Jul 12, 2023 · 2 comments

ofivite commented Jul 12, 2023

Hi! I've been looking into integrating muP into the Megatron-LM setup, and I was wondering about the `_rescale_parameters()` method of `MuReadout` in the case of shared (tied) input/output embeddings. Specifically, in the Transformer example I'm not sure it is in line with the embedding initialisation suggested in the paper (i.e., constant in width).

Currently, in the example:

- the encoder is initialised from N(0, 1), the default `nn.Embedding` init:
  `self.encoder = nn.Embedding(ntoken, ninp)`
- the decoder is first initialised inside `MuSharedReadout` from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), the default `nn.Linear` init:
  `super().__init__(*weight.shape, bias=bias, **kwargs)`
- but then the decoder weights are overwritten with the encoder's (the next line, 68), so they end up with the N(0, 1) init;
- finally, once `set_base_shapes()` is called, both encoder and decoder weights are rescaled inside `_rescale_parameters()` by `*= self.width_mult()**0.5`, which gives them std sqrt(d/d_0), i.e. an init that grows with width (see the sketch below).
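For concreteness, here is a simplified sketch of that code path (not the example verbatim; the class name and widths below are just placeholders):

```python
import torch.nn as nn
from mup import MuSharedReadout, set_base_shapes

class TiedModel(nn.Module):
    """Minimal stand-in for the Transformer example's tied-embedding setup."""
    def __init__(self, ntoken, ninp):
        super().__init__()
        # encoder: default nn.Embedding init -> N(0, 1)
        self.encoder = nn.Embedding(ntoken, ninp)
        # decoder: MuSharedReadout first gets the default nn.Linear init,
        # then ties its weight to the encoder's, i.e. N(0, 1)
        self.decoder = MuSharedReadout(self.encoder.weight, bias=False)

# set_base_shapes() triggers MuReadout._rescale_parameters(), which multiplies
# the shared weight by width_mult()**0.5, so the init scales with width.
model = set_base_shapes(TiedModel(ntoken=1000, ninp=256),
                        TiedModel(ntoken=1000, ninp=64))
```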

However, the muP paper suggests initialising them as constants (in width) to be muP compatible. It should also be mentioned that in the untied case the output embeddings are set to 0, so `_rescale_parameters()` has no effect and things are consistent with the paper.

Below I also attach the coordinate check plots for the Transformer example for untied, tied+rescaling (current implementation) and tied+no rescaling (_rescale_parameters() disabled), respectively. One can see that for untied the norms are nicely flat, for tied+rescaling some layers have growing activations, and for tied+no rescaling one layer has a vanishing trend.

So I was wondering: should `_rescale_parameters()` be disabled for the tied embedding scenario to keep the init constant, given that the N(0, 1) initialisation is inherited from `nn.Embedding()`?

[Coordinate check plots: μp_trsfmr_adam_coord (untied), μp_trsfmr_adam_coord_tied (tied + rescaling), μp_trsfmr_adam_coord_tied_fix (tied + no rescaling)]
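(For reference, coordinate check plots like these can be generated with mup's coord check utilities; a rough sketch along the lines of the mup README, where `make_model`, `base_shapes_path` and `dataloader` are placeholders for the example's model constructor, base shapes file and data:)

```python
from mup import set_base_shapes
from mup.coord_check import get_coord_data, plot_coord_data

# make_model(width) is a placeholder for building the Transformer example at a
# given width; base_shapes_path points to a saved .bsh base-shapes file.
models = {w: (lambda w=w: set_base_shapes(make_model(w), base_shapes_path))
          for w in (64, 128, 256, 512, 1024)}

dataloader = ...  # a small dataloader is enough for the check

df = get_coord_data(models, dataloader)         # records activation stats over a few steps
plot_coord_data(df, save_to='coord_check.png')  # produces plots like the ones above
```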

edwardjhu (Collaborator) commented

Thanks for pointing this out! Your analysis seems correct.

A simple fix is to add self._has_rescaled_params = True to the constructor of MuSharedReadout, so we don't trigger rescaling. I'll do that after making sure it doesn't have unintended consequences.
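In code, the change would look roughly like this (a sketch against the constructor shown above, not the final patch):

```python
from mup import MuReadout

class MuSharedReadout(MuReadout):
    """Readout layer whose weight is shared with an embedding layer."""
    def __init__(self, weight, bias=True, **kwargs):
        super().__init__(*weight.shape, bias=bias, **kwargs)
        self.weight = weight
        # Mark the parameters as already rescaled so that set_base_shapes()
        # skips _rescale_parameters() for the shared embedding weight.
        self._has_rescaled_params = True
```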

The vanishing preactivation in the last row should be the final logits. The GP behavior at init follows the CLT, while the μP scaling accounts for the LLN, hence the shrinking logits at init. A way to get rid of it is to initialise the shared embedding layer to zero, which is okay as long as the embedded input is not always zero (e.g., thanks to a non-zero positional embedding). You should get flat curves that way.
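A minimal sketch of that zero-init, assuming the example's `model.encoder` naming for the shared embedding:

```python
import torch.nn as nn

def zero_init_shared_embedding(model: nn.Module) -> None:
    """Zero-initialise the shared input/output embedding so the readout
    starts at zero, matching the untied case."""
    # encoder.weight is the same tensor as decoder.weight when tied via
    # MuSharedReadout, so this zeroes both at once.
    nn.init.zeros_(model.encoder.weight)
```

Since zero stays zero under any multiplicative rescaling, it does not matter whether this runs before or after `set_base_shapes()`; the embedded input just needs a non-zero contribution from elsewhere, e.g. the positional embedding.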

ofivite (Author) commented Aug 7, 2023

Thank you @edwardjhu for your answer and suggestions! Initialising the shared embeddings to zero is a good idea; I will try that out in the Megatron setup and see whether the curves come out flat there :)
