Commit

Add RNN example
* Add examples/rnn.rst
* Add link to the RNN example in examples/index.rst
* Add doc test in test_rst.py
ValerianRey committed Jan 2, 2025
1 parent 654dd88 commit 4d84d3e
Showing 3 changed files with 65 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/source/examples/index.rst
@@ -13,6 +13,8 @@ This section contains some usage examples for TorchJD.
- :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
descent is used to optimize the vector of per-task losses of a multi-task model, using the
dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
- :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
with one loss per output sequence element.
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
``LightningModule`` optimized by Jacobian descent.
@@ -23,4 +25,5 @@ This section contains some usage examples for TorchJD.
basic_usage.rst
iwrm.rst
mtl.rst
rnn.rst
lightning_integration.rst
38 changes: 38 additions & 0 deletions docs/source/examples/rnn.rst
@@ -0,0 +1,38 @@
Recurrent Neural Network (RNN)
==============================

When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
descent can be leveraged to enhance optimization.

.. code-block:: python
    :emphasize-lines: 5-6, 10, 17, 20

    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd import backward
    from torchjd.aggregation import UPGrad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.

    for input, target in zip(inputs, targets):
        output, _ = rnn(input)  # output is of shape [5, 3, 20].
        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

        optimizer.zero_grad()
        backward(losses, aggregator, parallel_chunk_size=1)
        optimizer.step()

.. note::
    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
    ``torch.nn.RNN`` when running on CUDA (see `this issue
    <https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise setting
    ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you
    can check whether ``parallel_chunk_size=None`` (maximal parallelization) works in your setup.
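
The note above suggests checking whether ``parallel_chunk_size=None`` works. Below is a minimal
sketch (not part of this commit) of one way to probe that once before training. The helper name
``pick_chunk_size``, the dummy tensor shapes, and the assumption that a ``torch.vmap``
incompatibility surfaces as an exception raised by ``backward`` are all illustrative, not part of
TorchJD's API.

import torch
from torch.nn import RNN

from torchjd import backward
from torchjd.aggregation import UPGrad


def pick_chunk_size(rnn: RNN, aggregator) -> int | None:
    """Return None if maximal parallelization works for this model, otherwise 1."""
    dummy_input = torch.randn(5, 3, 10)  # same sequence shape as in the example above
    dummy_target = torch.randn(5, 3, 20)
    output, _ = rnn(dummy_input)
    losses = ((output - dummy_target) ** 2).mean(dim=[1, 2])
    try:
        backward(losses, aggregator, parallel_chunk_size=None)  # relies on torch.vmap
        chunk_size = None
    except Exception:  # assumption: the vmap incompatibility surfaces as an exception here
        chunk_size = 1  # differentiate one loss at a time, avoiding torch.vmap
    rnn.zero_grad()  # discard the gradients accumulated by the probe
    return chunk_size


rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
chunk_size = pick_chunk_size(rnn, UPGrad())

In the training loop, one would then pass ``chunk_size`` to ``backward`` instead of hard-coding
``1``.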
24 changes: 24 additions & 0 deletions tests/doc/test_rst.py
@@ -183,3 +183,27 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
)

trainer.fit(model=model, train_dataloaders=train_loader)


def test_rnn():
    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd import backward
    from torchjd.aggregation import UPGrad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.

    for input, target in zip(inputs, targets):
        output, _ = rnn(input)  # output is of shape [5, 3, 20].
        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

        optimizer.zero_grad()
        backward(losses, aggregator, parallel_chunk_size=1)
        optimizer.step()
