docs: Add AMP usage example (#245)
* Add AMP example in amp.rst
* Add link to the AMP example in examples/index.rst
* Add doc test for AMP example in doc/test_rst.py
* Add changelog entry
ValerianRey authored Feb 9, 2025
1 parent 1d87f56 commit 2fae5ea
Showing 4 changed files with 111 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,10 @@ changes that do not affect the user.

## [Unreleased]

### Added

- Added usage example showing how to combine TorchJD with automatic mixed precision (AMP).

### Changed

- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
64 changes: 64 additions & 0 deletions docs/source/examples/amp.rst
@@ -0,0 +1,64 @@
Automatic Mixed Precision (AMP)
===============================

In some cases, to save memory and reduce computation time, you may want to use `automatic mixed
precision <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_. Since the
`torch.amp.GradScaler <https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_ class already
supports multiple losses, combining TorchJD with AMP is straightforward. As usual, the forward
pass should be wrapped in a `torch.autocast
<https://pytorch.org/docs/stable/amp.html#torch.autocast>`_ context, and the losses (here several
of them rather than a single one) should preferably be scaled with a `GradScaler
<https://pytorch.org/docs/stable/amp.html#gradient-scaling>`_ to avoid gradient underflow. The
following example shows the resulting code for a multi-task learning use case.

.. code-block:: python
    :emphasize-lines: 2, 17, 27, 34, 36-38

    import torch
    from torch.amp import GradScaler
    from torch.nn import Sequential, Linear, ReLU, MSELoss
    from torch.optim import SGD
    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    scaler = GradScaler(device="cpu")
    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
        with torch.autocast(device_type="cpu", dtype=torch.float16):
            features = shared_module(input)
            output1 = task1_module(features)
            output2 = task2_module(features)
            loss1 = loss_fn(output1, target1)
            loss2 = loss_fn(output2, target2)

        scaled_losses = scaler.scale([loss1, loss2])
        optimizer.zero_grad()
        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
        scaler.step(optimizer)
        scaler.update()

.. hint::
    Within the ``torch.autocast`` context, some operations may be performed in ``float16``. For
    those operations, the tensors saved for the backward pass will also be of ``float16`` type.
    However, the Jacobian computed by ``mtl_backward`` will be of type ``float32``, so the
    ``.grad`` fields of the model parameters will also be of type ``float32``. This matches the
    behavior of PyTorch, which would also compute all gradients in ``float32``.
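
    For instance, one could add the following check after the training loop (a minimal sketch,
    assuming the loop above has run at least once; the assertions are illustrative and not part
    of the example):

    .. code-block:: python

        for p in params:
            # Gradients aggregated by mtl_backward are stored in float32,
            # even though parts of the forward pass ran in float16.
            assert p.grad is not None
            assert p.grad.dtype == torch.float32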

.. note::
    :doc:`torchjd.backward <../docs/autojac/backward>` can be similarly combined with AMP, as
    sketched below.
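
For reference, here is a minimal sketch of that combination for a single model with two losses.
The second criterion (an L1 loss) is only an illustrative choice, and the ``backward`` call
follows the signature used in ``tests/doc/test_rst.py`` below:

.. code-block:: python

    import torch
    from torch.amp import GradScaler
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD
    from torchjd import backward
    from torchjd.aggregation import UPGrad

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 1))
    scaler = GradScaler(device="cpu")
    mse = MSELoss()
    optimizer = SGD(model.parameters(), lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 16, 10)
    targets = torch.randn(8, 16, 1)

    for input, target in zip(inputs, targets):
        with torch.autocast(device_type="cpu", dtype=torch.float16):
            output = model(input)
            loss1 = mse(output, target)
            loss2 = (output - target).abs().mean()  # illustrative second criterion (L1)

        # Scale both losses, then let backward aggregate their Jacobian.
        scaled_losses = scaler.scale([loss1, loss2])
        optimizer.zero_grad()
        backward(scaled_losses, aggregator)
        scaler.step(optimizer)
        scaler.update()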
2 changes: 2 additions & 0 deletions docs/source/examples/index.rst
@@ -18,6 +18,7 @@ This section contains some usage examples for TorchJD.
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
``LightningModule`` optimized by Jacobian descent.
- :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.

.. toctree::
:hidden:
@@ -27,3 +28,4 @@ This section contains some usage examples for TorchJD.
mtl.rst
rnn.rst
lightning_integration.rst
amp.rst
41 changes: 41 additions & 0 deletions tests/doc/test_rst.py
@@ -207,3 +207,44 @@ def test_rnn():
        optimizer.zero_grad()
        backward(losses, aggregator, parallel_chunk_size=1)
        optimizer.step()


def test_amp():
    import torch
    from torch.amp import GradScaler
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    scaler = GradScaler(device="cpu")
    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
        with torch.autocast(device_type="cpu", dtype=torch.float16):
            features = shared_module(input)
            output1 = task1_module(features)
            output2 = task2_module(features)
            loss1 = loss_fn(output1, target1)
            loss2 = loss_fn(output2, target2)

        scaled_losses = scaler.scale([loss1, loss2])
        optimizer.zero_grad()
        mtl_backward(losses=scaled_losses, features=features, aggregator=aggregator)
        scaler.step(optimizer)
        scaler.update()
