From c085dbc3f3811e7c94a35b65d9c8ba94f8afc8bb Mon Sep 17 00:00:00 2001 From: janEbert Date: Fri, 3 May 2024 14:52:43 +0200 Subject: [PATCH 1/3] Fix FSDP --- mup/layer.py | 18 ++++++++++++------ mup/shape.py | 10 ++++++++++ 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/mup/layer.py b/mup/layer.py index 518a33b..8bb07af 100644 --- a/mup/layer.py +++ b/mup/layer.py @@ -25,12 +25,18 @@ def reset_parameters(self) -> None: super().reset_parameters() def width_mult(self): - assert hasattr(self.weight, 'infshape'), ( - 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' - 'switch to distributed training with ' - 'torch.nn.parallel.DistributedDataParallel instead' - ) - return self.weight.infshape.width_mult() + if not hasattr(self.weight, 'infshape'): + if not hasattr(self, 'weight_infshape'): + raise AssertionError( + 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' + 'switch to distributed training with ' + 'torch.nn.parallel.DistributedDataParallel instead' + ) + else: + width_mult = self.weight_infshape.width_mult() + else: + width_mult = self.weight.infshape.width_mult() + return width_mult def _rescale_parameters(self): '''Rescale parameters to convert SP initialization to μP initialization. diff --git a/mup/shape.py b/mup/shape.py index 6889e0b..69cc790 100644 --- a/mup/shape.py +++ b/mup/shape.py @@ -156,6 +156,13 @@ def apply_infshapes(model, infshapes): for name, p in model.named_parameters(): p.infshape = infshapes[name] + +def _fix_fsdp_readout(module): + assert isinstance(module, MuReadout) + assert hasattr(module.weight, 'infshape') + module.weight_infshape = module.weight.infshape + + def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True): '''Sets the `p.infshape` attribute for each parameter `p` of `model`. @@ -192,6 +199,9 @@ def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, module._rescale_parameters() elif isinstance(module, (Linear, _ConvNd)): rescale_linear_bias(module) + for name, module in model.named_modules(): + if isinstance(module, MuReadout): + _fix_fsdp_readout(module) return model def assert_hidden_size_inf(model): From be217f6c15e35d5b75e8e32db7465e180522b87c Mon Sep 17 00:00:00 2001 From: janEbert Date: Fri, 3 May 2024 15:52:49 +0200 Subject: [PATCH 2/3] Explain FSDP caveat --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c87e1b1..732988a 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ For example, the Adam learning rate of hidden weights `p` is calculated as `glo - `set_base_shapes(model, ...)` assumes that `model` has just been randomly initialized in the standard way and rescales its parameters using the base shape information so the model is in μP. - If you want data parallelism, please use `torch.nn.parallel.DistributedDataParallel` instead of `torch.nn.DataParallel`. This is because the latter removes the attributes the `mup` package adds to each parameter tensor of the model. Also, for performance, `pytorch` [recommends the former anyway](https://pytorch.org/docs/stable/notes/cuda.html#cuda-nn-ddp-instead). +- For `FullyShardedDataParallel` (FSDP) usage, you have to use a PyTorch version ≥2 and wrap your model like `FSDP(..., use_orig_params=True)`. - We scale the learning rate according to μP explicitly by creating refined parameter groups from what is passed to the `mup` optimizer and by manipulating the `lr` attribute in those groups. This is compatible with PyTorch's learning rate schedulers. However, if you roll your own, make sure the scheduler sets the learning rate relative to what is currently in the refined parameter groups. The following is an example of what *not* to do and what is OK: ```python optimizer = mup.MuAdam(model.parameters(), lr=1e-3) From 22ca9dd6964f8b7e2d309e2ec588f71edf26f78c Mon Sep 17 00:00:00 2001 From: janEbert Date: Fri, 3 May 2024 16:08:33 +0200 Subject: [PATCH 3/3] Improve logical flow This is nicer to read than the nested if-queries before. --- mup/layer.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/mup/layer.py b/mup/layer.py index 8bb07af..09c8161 100644 --- a/mup/layer.py +++ b/mup/layer.py @@ -25,17 +25,16 @@ def reset_parameters(self) -> None: super().reset_parameters() def width_mult(self): - if not hasattr(self.weight, 'infshape'): - if not hasattr(self, 'weight_infshape'): - raise AssertionError( - 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' - 'switch to distributed training with ' - 'torch.nn.parallel.DistributedDataParallel instead' - ) - else: - width_mult = self.weight_infshape.width_mult() - else: + if hasattr(self.weight, 'infshape'): width_mult = self.weight.infshape.width_mult() + elif hasattr(self, 'weight_infshape'): + width_mult = self.weight_infshape.width_mult() + else: + raise AssertionError( + 'Please call set_base_shapes(...). If using torch.nn.DataParallel, ' + 'switch to distributed training with ' + 'torch.nn.parallel.DistributedDataParallel instead' + ) return width_mult def _rescale_parameters(self):