Support FSDP usage #72

Open · wants to merge 3 commits into base: main
README.md — 1 change: 1 addition & 0 deletions
@@ -145,6 +145,7 @@ For example, the Adam learning rate of hidden weights `p` is calculated as `glo…
 
 - `set_base_shapes(model, ...)` assumes that `model` has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP.
 - If you want data parallelism, please use `torch.nn.parallel.DistributedDataParallel` instead of `torch.nn.DataParallel`. This is because the latter removes the attributes the `mup` package adds to each parameter tensor of the model. Also, for performance, `pytorch` [recommends the former anyway](https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead).
+- For `FullyShardedDataParallel` (FSDP) usage, you have to use PyTorch ≥ 2 and wrap your model like `FSDP(..., use_orig_params=True)`.
 - We scale the learning rate according to μP explicitly by creating refined parameter groups from what is passed to the `mup` optimizer and by manipulating the `lr` attribute in those groups. This is compatible with PyTorch's learning rate schedulers. However, if you roll your own, make sure the scheduler sets the learning rate relative to what is currently in the refined parameter groups. The following is an example of what *not* to do and what is OK:
 ```python
 optimizer = mup.MuAdam(model.parameters(), lr=1e-3)
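The README example is cut off by the hunk above. A minimal sketch consistent with the surrounding prose (not necessarily the README's exact wording), assuming `model` has already been put into μP via `set_base_shapes`:

```python
import mup

optimizer = mup.MuAdam(model.parameters(), lr=1e-3)

# A hand-rolled "scheduler" must adjust lr relative to each group's current value.
for param_group in optimizer.param_groups:
    param_group['lr'] = 1e-3   # not OK: overwrites the μP-scaled learning rate of this group
    param_group['lr'] *= 0.1   # OK: rescales relative to the group's current, μP-scaled value
```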
mup/layer.py — 17 changes: 11 additions & 6 deletions
@@ -25,12 +25,17 @@ def reset_parameters(self) -> None:
         super().reset_parameters()
 
     def width_mult(self):
-        assert hasattr(self.weight, 'infshape'), (
-            'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
-            'switch to distributed training with '
-            'torch.nn.parallel.DistributedDataParallel instead'
-        )
-        return self.weight.infshape.width_mult()
+        if hasattr(self.weight, 'infshape'):
+            width_mult = self.weight.infshape.width_mult()
+        elif hasattr(self, 'weight_infshape'):
+            width_mult = self.weight_infshape.width_mult()
+        else:
+            raise AssertionError(
+                'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
+                'switch to distributed training with '
+                'torch.nn.parallel.DistributedDataParallel instead'
+            )
+        return width_mult
 
     def _rescale_parameters(self):
         '''Rescale parameters to convert SP initialization to μP initialization.
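With this change, `width_mult()` keeps working even if the weight tensor loses its `infshape` attribute, as can happen when a wrapper such as FSDP swaps out parameter tensors. A minimal sketch of that behaviour, assuming this PR's changes are installed (the attribute loss is simulated by hand rather than by actually wrapping in FSDP):

```python
import torch
from mup import MuReadout, set_base_shapes

readout = MuReadout(256, 10)     # target-width readout layer
base = MuReadout(64, 10)         # narrow copy used only for base-shape information
set_base_shapes(readout, base)   # attaches .infshape and, with this PR, .weight_infshape

print(readout.width_mult())      # 4.0, read from readout.weight.infshape

# Simulate a wrapper replacing the parameter tensor and dropping its custom attributes.
readout.weight = torch.nn.Parameter(readout.weight.detach().clone())
print(readout.width_mult())      # still 4.0, via the module-level weight_infshape fallback
```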
mup/shape.py — 10 changes: 10 additions & 0 deletions
@@ -156,6 +156,13 @@ def apply_infshapes(model, infshapes):
     for name, p in model.named_parameters():
         p.infshape = infshapes[name]
 
+
+def _fix_fsdp_readout(module):
+    assert isinstance(module, MuReadout)
+    assert hasattr(module.weight, 'infshape')
+    module.weight_infshape = module.weight.infshape
+
+
 def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True):
     '''Sets the `p.infshape` attribute for each parameter `p` of `model`.
 
@@ -192,6 +199,9 @@ def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None,
                 module._rescale_parameters()
             elif isinstance(module, (Linear, _ConvNd)):
                 rescale_linear_bias(module)
+    for name, module in model.named_modules():
+        if isinstance(module, MuReadout):
+            _fix_fsdp_readout(module)
     return model
 
 def assert_hidden_size_inf(model):
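For context, a rough end-to-end sketch of how the pieces in this PR are meant to fit together: `set_base_shapes` now also stores `weight_infshape` on each `MuReadout`, so `width_mult()` survives FSDP wrapping with `use_orig_params=True`. The model below is an illustrative placeholder, and the snippet assumes PyTorch ≥ 2 with a distributed process group already initialized (e.g. launched via `torchrun`); it is a sketch, not the authors' official recipe.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from mup import MuAdam, MuReadout, set_base_shapes

def build_model(width):
    # Tiny illustrative network ending in a MuReadout; stands in for a real model.
    return nn.Sequential(nn.Linear(32, width), nn.ReLU(), MuReadout(width, 10))

model = build_model(1024)
base = build_model(64)            # narrow copy used only for base-shape information
set_base_shapes(model, base)      # attaches infshape (and, with this PR, weight_infshape)

# Wrap *after* set_base_shapes; the README note asks for use_orig_params=True.
model = FSDP(model, use_orig_params=True)

optimizer = MuAdam(model.parameters(), lr=1e-3)  # μP-scaled learning rates per parameter group
```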