
Commit 87b76bd

FSDP1 deprecation msg
1 parent 78a6c91 commit 87b76bd

File tree

3 files changed: +35 -35 lines changed


index.rst

Lines changed: 3 additions & 3 deletions
@@ -766,14 +766,14 @@ Welcome to PyTorch Tutorials
   :tags: Parallel-and-Distributed-Training

.. customcarditem::
-   :header: Getting Started with Fully Sharded Data Parallel(FSDP)
-   :card_description: Learn how to train models with Fully Sharded Data Parallel package.
+   :header: Getting Started with Fully Sharded Data Parallel (FSDP2)
+   :card_description: Learn how to train models with Fully Sharded Data Parallel (fully_shard) package.
   :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png
   :link: intermediate/FSDP_tutorial.html
   :tags: Parallel-and-Distributed-Training

.. customcarditem::
-   :header: Advanced Model Training with Fully Sharded Data Parallel (FSDP)
+   :header: Advanced Model Training with Fully Sharded Data Parallel (FSDP1)
   :card_description: Explore advanced model training with Fully Sharded Data Parallel package.
   :image: _static/img/thumbnails/cropped/Getting-Started-with-FSDP.png
   :link: intermediate/FSDP_advanced_tutorial.html

intermediate_source/FSDP1_tutorial.rst

Lines changed: 26 additions & 26 deletions
@@ -4,19 +4,19 @@ Getting Started with Fully Sharded Data Parallel(FSDP)
**Author**: `Hamid Shojanazeri <https://github.com/HamidShojanazeri>`__, `Yanli Zhao <https://github.com/zhaojuanmao>`__, `Shen Li <https://mrshenli.github.io/>`__

.. note::
-   |edit| View and edit this tutorial in `github <https://github.com/pytorch/tutorials/blob/main/intermediate_source/FSDP_tutorial.rst>`__.
+   |edit| FSDP1 is deprecated. Please check out `FSDP2 tutorial <https://docs.pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`_.

- Training AI models at a large scale is a challenging task that requires a lot of compute power and resources.
+ Training AI models at a large scale is a challenging task that requires a lot of compute power and resources.
It also comes with considerable engineering complexity to handle the training of these very large models.
`PyTorch FSDP <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`__, released in PyTorch 1.11 makes this easier.

- In this tutorial, we show how to use `FSDP APIs <https://pytorch.org/docs/stable/fsdp.html>`__, for simple MNIST models that can be extended to other larger models such as `HuggingFace BERT models <https://huggingface.co/blog/zero-deepspeed-fairscale>`__,
- `GPT 3 models up to 1T parameters <https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ . The sample DDP MNIST code courtesy of `Patrick Hu <https://github.com/yqhu/>`_.
+ In this tutorial, we show how to use `FSDP APIs <https://pytorch.org/docs/stable/fsdp.html>`__, for simple MNIST models that can be extended to other larger models such as `HuggingFace BERT models <https://huggingface.co/blog/zero-deepspeed-fairscale>`__,
+ `GPT 3 models up to 1T parameters <https://pytorch.medium.com/training-a-1-trillion-parameter-model-with-pytorch-fully-sharded-data-parallel-on-aws-3ac13aa96cff>`__ . The sample DDP MNIST code courtesy of `Patrick Hu <https://github.com/yqhu/>`_.


How FSDP works
--------------
- In `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks.
+ In `DistributedDataParallel <https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html>`__, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks.

When training with FSDP, the GPU memory footprint is smaller than when training with DDP across all workers. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. This comes with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like overlapping communication and computation.

@@ -44,7 +44,7 @@ At a high level FSDP works as follow:
* Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
* Run backward computation
* Run reduce_scatter to sync gradients
- * Discard parameters.
+ * Discard parameters.

One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.

@@ -57,7 +57,7 @@ One way to view FSDP's sharding is to decompose the DDP gradient all-reduce into

How to use FSDP
---------------
- Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.
+ Here we use a toy model to run training on the MNIST dataset for demonstration purposes. The APIs and logic can be applied to training larger models as well.

*Setup*

@@ -116,7 +116,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
    def cleanup():
        dist.destroy_process_group()

- 2.1  Define our toy model for handwritten digit classification.
+ 2.1  Define our toy model for handwritten digit classification.

.. code-block:: python

@@ -131,7 +131,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
-
+
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
@@ -146,7 +146,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
        output = F.log_softmax(x, dim=1)
        return output

- 2.2 Define a train function
+ 2.2 Define a train function

.. code-block:: python

@@ -169,7 +169,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
    if rank == 0:
        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))

- 2.3 Define a validation function
+ 2.3 Define a validation function

.. code-block:: python

@@ -230,8 +230,8 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
        size_based_auto_wrap_policy, min_num_params=100
    )
    torch.cuda.set_device(rank)
-
-
+
+
    init_start_event = torch.cuda.Event(enable_timing=True)
    init_end_event = torch.cuda.Event(enable_timing=True)

@@ -261,7 +261,7 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
    states = model.state_dict()
    if rank == 0:
        torch.save(states, "mnist_cnn.pt")
-
+
    cleanup()


@@ -309,7 +309,7 @@ We have recorded cuda events to measure the time of FSDP model specifics. The CU
    CUDA event elapsed time on training loop 40.67462890625sec

Wrapping the model with FSDP, the model will look as follows, we can see the model has been wrapped in one FSDP unit.
- Alternatively, we will look at adding the auto_wrap_policy next and will discuss the differences.
+ Alternatively, we will look at adding the auto_wrap_policy next and will discuss the differences.

.. code-block:: bash

@@ -326,7 +326,7 @@ Alternatively, we will look at adding the auto_wrap_policy next and will discuss
        )
    )

- The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.
+ The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.


.. figure:: /_static/img/distributed/FSDP_memory.gif
@@ -336,18 +336,18 @@ The following is the peak memory usage from FSDP MNIST training on g4dn.12.xlarg

   FSDP Peak Memory Usage

- Applying *auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
- The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model.
+ Applying *auto_wrap_policy* in FSDP otherwise, FSDP will put the entire model in one FSDP unit, which will reduce computation efficiency and memory efficiency.
+ The way it works is that, suppose your model contains 100 Linear layers. If you do FSDP(model), there will only be one FSDP unit which wraps the entire model.
In that case, the allgather would collect the full parameters for all 100 linear layers, and hence won't save CUDA memory for parameter sharding.
- Also, there is only one blocking allgather call for the all 100 linear layers, there will not be communication and computation overlapping between layers.
+ Also, there is only one blocking allgather call for the all 100 linear layers, there will not be communication and computation overlapping between layers.

To avoid that, you can pass in an auto_wrap_policy, which will seal the current FSDP unit and start a new one automatically when the specified condition is met (e.g., size limit).
In that way you will have multiple FSDP units, and only one FSDP unit needs to collect full parameters at a time. E.g., suppose you have 5 FSDP units, and each wraps 20 linear layers.
Then, in the forward, the 1st FSDP unit will allgather parameters for the first 20 linear layers, do computation, discard the parameters and then move on to the next 20 linear layers. So, at any point in time, each rank only materializes parameters/grads for 20 linear layers instead of 100.


To do so in 2.4 we define the auto_wrap_policy and pass it to FSDP wrapper, in the following example, my_auto_wrap_policy defines that a layer could be wrapped or sharded by FSDP if the number of parameters in this layer is larger than 100.
- If the number of parameters in this layer is smaller than 100, it will be wrapped with other small layers together by FSDP.
+ If the number of parameters in this layer is smaller than 100, it will be wrapped with other small layers together by FSDP.
Finding an optimal auto wrap policy is challenging, PyTorch will add auto tuning for this config in the future. Without an auto tuning tool, it is good to profile your workflow using different auto wrap policies experimentally and find the optimal one.

.. code-block:: python
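For reference, a minimal sketch of the size-based wrapping described above (illustrative only, not the snippet from the diff; it assumes ``dist.init_process_group`` and ``torch.cuda.set_device(rank)`` have already run as in the earlier snippets, and uses a stand-in 100-layer model):

.. code-block:: python

    import functools

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

    # Seal an FSDP unit whenever the accumulated parameter count exceeds min_num_params,
    # so parameters are gathered (and freed) unit by unit instead of all at once.
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy, min_num_params=100
    )

    model = nn.Sequential(*[nn.Linear(128, 128) for _ in range(100)]).cuda()
    sharded_model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)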
@@ -388,7 +388,7 @@ Applying the auto_wrap_policy, the model would be as follows:

    CUDA event elapsed time on training loop 41.89130859375sec

- The following is the peak memory usage from FSDP with auto_wrap policy of MNIST training on a g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.
+ The following is the peak memory usage from FSDP with auto_wrap policy of MNIST training on a g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch Profiler.
It can be observed that the peak memory usage on each device is smaller compared to FSDP without auto wrap policy applied, from ~75 MB to 66 MB.

.. figure:: /_static/img/distributed/FSDP_autowrap.gif
@@ -398,13 +398,13 @@ It can be observed that the peak memory usage on each device is smaller compared

   FSDP Peak Memory Usage using Auto_wrap policy

- *CPU Off-loading*: In case the model is very large that even with FSDP wouldn't fit into GPUs, then CPU offload can be helpful here.
+ *CPU Off-loading*: In case the model is very large that even with FSDP wouldn't fit into GPUs, then CPU offload can be helpful here.

Currently, only parameter and gradient CPU offload is supported. It can be enabled via passing in cpu_offload=CPUOffload(offload_params=True).

Note that this currently implicitly enables gradient offloading to CPU in order for params and grads to be on the same device to work with the optimizer. This API is subject to change. The default is None in which case there will be no offloading.

- Using this feature may slow down the training considerably, due to frequent copying of tensors from host to device, but it could help improve memory efficiency and train larger scale models.
+ Using this feature may slow down the training considerably, due to frequent copying of tensors from host to device, but it could help improve memory efficiency and train larger scale models.

In 2.4 we just add it to the FSDP wrapper

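As described above, CPU offload is configured on the FSDP constructor. A minimal sketch (illustrative only; it reuses ``model`` and ``my_auto_wrap_policy`` from the snippets above and assumes the process group and device are already set up):

.. code-block:: python

    from torch.distributed.fsdp import CPUOffload
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Keep sharded parameters (and, implicitly, gradients) on CPU between uses;
    # they are copied to the GPU only while an FSDP unit is being computed.
    sharded_model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=True),
    )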
@@ -430,7 +430,7 @@ Compare it with DDP, if in 2.4 we just normally wrap the model in DPP, saving th

    CUDA event elapsed time on training loop 39.77766015625sec

- The following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch profiler.
+ The following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge AWS EC2 instance with 4 GPUs captured from PyTorch profiler.

.. figure:: /_static/img/distributed/DDP_memory.gif
   :width: 100%
@@ -440,9 +440,9 @@ The following is the peak memory usage from DDP MNIST training on g4dn.12.xlarge
   DDP Peak Memory Usage using Auto_wrap policy


- Considering the toy example and tiny MNIST model we defined here, we can observe the difference between peak memory usage of DDP and FSDP.
+ Considering the toy example and tiny MNIST model we defined here, we can observe the difference between peak memory usage of DDP and FSDP.
In DDP each process holds a replica of the model, so the memory footprint is higher compared to FSDP which shards the model parameters, optimizer states and gradients over DDP ranks.
- The peak memory usage using FSDP with auto_wrap policy is the lowest followed by FSDP and DDP.
+ The peak memory usage using FSDP with auto_wrap policy is the lowest followed by FSDP and DDP.

Also, looking at timings, considering the small model and running the training on a single machine, FSDP with and without auto_wrap policy performed almost as fast as DDP.
This example does not represent most of the real applications, for detailed analysis and comparison between DDP and FSDP please refer to this `blog post <https://pytorch.medium.com/6c8da2be180d>`__ .
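For the DDP baseline referenced in this comparison, the wrapping itself is a one-liner. A sketch (illustrative only; ``Net`` stands in for the tutorial's toy model class and ``rank`` for the local GPU index, both assumed from the surrounding snippets):

.. code-block:: python

    from torch.nn.parallel import DistributedDataParallel as DDP

    # Every rank keeps a full replica of the model; gradients are synchronized with
    # all-reduce, so peak memory stays higher than with FSDP's parameter sharding.
    model = Net().to(rank)
    ddp_model = DDP(model, device_ids=[rank])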

intermediate_source/FSDP_tutorial.rst

Lines changed: 6 additions & 6 deletions
@@ -1,10 +1,10 @@
- Getting Started with Fully Sharded Data Parallel(FSDP)
+ Getting Started with Fully Sharded Data Parallel (FSDP2)
======================================================

**Author**: `Wei Feng <https://github.com/weifengpy>`__, `Will Constable <https://github.com/wconstab>`__, `Yifan Mao <https://github.com/mori360>`__

.. note::
-   |edit| Check out the code in this tutorial from `pytorch/examples <https://github.com/pytorch/examples/tree/main/distributed/FSDP2>`__.
+   |edit| Check out the code in this tutorial from `pytorch/examples <https://github.com/pytorch/examples/tree/main/distributed/FSDP2>`_. FSDP1 will be deprecated. The old tutorial can be found `here <https://docs.pytorch.org/tutorials/intermediate/FSDP1_tutorial.html>`_.

How FSDP2 works
--------------
@@ -166,7 +166,7 @@ Explicit prefetching works well in following situation:
Enabling Mixed Precision
~~~~~~~~~~~~~~~

- FSDP2 offers a flexible `mixed precision policy <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.MixedPrecisionPolicy>`_ to speed up training. One typical use case are
+ FSDP2 offers a flexible `mixed precision policy <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.MixedPrecisionPolicy>`_ to speed up training. One typical use case is

* Casting float32 parameters to bfloat16 for forward/backward computation, see ``param_dtype=torch.bfloat16``
* Upcasting gradients to float32 for reduce-scatter to preserve accuracy, see ``reduce_dtype=torch.float32``
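A minimal sketch of that policy applied with ``fully_shard`` (illustrative only; it assumes a model whose transformer blocks are exposed as ``model.layers``, following the FSDP2 tutorial's wrapping pattern):

.. code-block:: python

    import torch
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    # Compute forward/backward in bfloat16, but reduce-scatter gradients in float32
    # so cross-rank gradient accumulation keeps full precision.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )

    for layer in model.layers:
        fully_shard(layer, mp_policy=mp_policy)
    fully_shard(model, mp_policy=mp_policy)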
@@ -399,13 +399,13 @@ sync_module_states=True/False: Moved to DCP. User can broadcast state dicts from

forward_prefetch: Manual control over prefetching is possible with

- * Manually call ``fsdp_module.unshard()``
- * Use these APIs to control automatic prefetching, ``set_modules_to_forward_prefetch`` and ``set_modules_to_backward_prefetch``
+ * Manually call `fsdp_module.unshard() <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.unshard>`_
+ * Use these APIs to control automatic prefetching, `set_modules_to_forward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_forward_prefetch>`_ and `set_modules_to_backward_prefetch <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_modules_to_backward_prefetch>`_

limit_all_gathers: No longer needed, because ``fully_shard`` removed cpu synchronization

use_orig_params: Original params are always used (no more flat parameter)

no_sync(): `set_requires_gradient_sync <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_requires_gradient_sync>`_

ignored_params and ignored_states: `ignored_params <https://docs.pytorch.org/docs/main/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.fully_shard>`_
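To make these replacements concrete, a short sketch of the FSDP2 module-level controls referenced above (illustrative only; ``model``, its ``model.layers`` list, and ``inputs`` are assumed to exist and to have been wrapped with ``fully_shard`` as in the tutorial):

.. code-block:: python

    # Explicit prefetching: while layer i runs forward, start all-gathering layers i+1 and i+2.
    layers = list(model.layers)
    for i, layer in enumerate(layers):
        nxt = layers[i + 1 : i + 3]
        if nxt:
            layer.set_modules_to_forward_prefetch(nxt)

    # Gradient accumulation without reduce-scatter (the replacement for FSDP1's no_sync()).
    model.set_requires_gradient_sync(False)   # skip gradient reduction on this microbatch
    loss = model(inputs).sum()
    loss.backward()
    model.set_requires_gradient_sync(True)    # reduce gradients on the final microbatch

    # Manual unshard: eagerly all-gather one module's parameters ahead of its use.
    layers[0].unshard()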
