PyTorch has its own native implementation of FSDP, upstreamed from the FairScale project. It was introduced in the v1.11.0 release, but it is recommended to use it with PyTorch v1.12 or later, which is what Lightning supports.
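If you are unsure which PyTorch version your environment provides, a quick check along these lines can help (a minimal sketch; using the packaging library here is just one convenient option, not something the strategy requires):

import torch
from packaging.version import Version

# FSDP is recommended with PyTorch 1.12 or later.
if Version(torch.__version__) < Version("1.12.0"):
    raise RuntimeError(f"PyTorch >= 1.12 is recommended for FSDP, found {torch.__version__}")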
Model layers should be wrapped in FSDP in a nested way to reduce peak memory and to enable overlapping of communication and computation. The simplest way to do this is auto wrapping, which can serve as a drop-in replacement for DDP without changing the rest of the code: you don't have to wrap layers yourself as you would with manual wrapping.
When initializing the optimizers inside the configure_optimizers hook, make sure to use self.trainer.model.parameters(), otherwise PyTorch will raise an error. This is required because with auto-wrap the model layers are sharded, and your lightning_module.parameters() will return a generator with no parameters. This inconvenience will be addressed in the future.
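For example, configure_optimizers would reference the wrapped model through the trainer (a minimal sketch; the class name MyAutoWrappedModel and the learning rate are illustrative):

import torch
from pytorch_lightning import LightningModule


class MyAutoWrappedModel(LightningModule):  # hypothetical module name, for illustration only
    def configure_optimizers(self):
        # With auto-wrap, reference the sharded model via the trainer;
        # `self.parameters()` would yield an empty generator here.
        return torch.optim.AdamW(self.trainer.model.parameters(), lr=1e-3)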
To enable auto wrapping, simply pass the strategy to the Trainer:

from lightning_fairscale.strategies import DDPFullyShardedStrategy
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

model = BoringModel()
# Auto-wrap: the strategy shards the model layers for you.
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPFullyShardedStrategy(), precision=16)
trainer.fit(model)
Read more here.
Manual wrapping can be useful to explore complex sharding strategies by applying wrap selectively to some parts of the model. To activate parameter sharding with manual wrapping, wrap your model using the wrap function. Internally, Lightning enables a context manager around the configure_sharded_model hook to make sure the wrap parameters are passed correctly.
When not using Fully Sharded, these wrap calls are a no-op, so once the changes have been made there is no need to remove them for other strategies.
wrap simply wraps the module in a Fully Sharded Parallel class with the correct parameters taken from the Lightning context manager. Here's an example that uses wrap to create your model:
import torch
import torch.nn as nn
from lightning_fairscale.strategies import DDPFullyShardedStrategy
from pytorch_lightning import Trainer, LightningModule
from torch.distributed.fsdp.wrap import wrap


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(32, 32)
        self.block = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))

    def configure_sharded_model(self):
        # Modules are sharded across processes
        # as soon as they are wrapped with `wrap`.
        # During the forward/backward passes, weights get synced across processes
        # and de-allocated once computation is complete, saving memory.

        # Wraps the layer in a Fully Sharded Wrapper automatically
        linear_layer = wrap(self.linear_layer)

        for i, layer in enumerate(self.block):
            self.block[i] = wrap(layer)

        self.model = nn.Sequential(linear_layer, nn.ReLU(), self.block)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters())


model = MyModel()
trainer = Trainer(accelerator="gpu", devices=4, strategy=DDPFullyShardedStrategy(), precision=16)
trainer.fit(model)
You can customize the strategy configuration by adjusting the arguments of :class:~lightning_fairscale.strategies.DDPFullyShardedStrategy and passing it to the strategy argument of the Trainer:
from pytorch_lightning import Trainer
from lightning_fairscale.strategies import DDPFullyShardedStrategy

# Offload sharded parameters to CPU to further reduce GPU memory usage.
native_fsdp = DDPFullyShardedStrategy(cpu_offload=True)
trainer = Trainer(strategy=native_fsdp, accelerator="gpu", devices=4)
Check out this tutorial to learn more about native support.
Activation checkpointing reduces GPU memory usage by avoiding the storage of intermediate activation tensors in selected layers. The tradeoff is that the computational cost of the backward pass increases, as the dropped activations need to be recomputed.
Enable checkpointing on large layers (like Transformers) by providing the layer class/type to the strategy:
from pytorch_lightning import Trainer
from lightning_fairscale.strategies import DDPFullyShardedStrategy

fsdp = DDPFullyShardedStrategy(
    activation_checkpointing=MyTransformerBlock,  # or pass a list with multiple types
)
trainer = Trainer(strategy=fsdp, accelerator="gpu", devices=4)
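Note that MyTransformerBlock above stands in for your own layer class and is not provided by Lightning; a minimal, hypothetical definition used only to make the example self-contained could look like this:

import torch.nn as nn


class MyTransformerBlock(nn.Module):
    # Toy block used only to illustrate passing a layer type to
    # `activation_checkpointing`; replace it with your real transformer block.
    def __init__(self, dim: int = 32):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return self.ff(x + attn_out)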