diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst
index 5278cd76..62058bd8 100644
--- a/docs/source/examples/index.rst
+++ b/docs/source/examples/index.rst
@@ -13,6 +13,8 @@ This section contains some usage examples for TorchJD.
 - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
   descent is used to optimize the vector of per-task losses of a multi-task model, using the
   dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
+- :doc:`Recurrent Neural Network (RNN) <rnn>` shows how to apply Jacobian descent to RNN training,
+  with one loss per output sequence element.
 - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine TorchJD
   with PyTorch Lightning, by providing an example implementation of a multi-task
   ``LightningModule`` optimized by Jacobian descent.
@@ -23,4 +25,5 @@ This section contains some usage examples for TorchJD.
   basic_usage.rst
   iwrm.rst
   mtl.rst
+  rnn.rst
   lightning_integration.rst
diff --git a/docs/source/examples/rnn.rst b/docs/source/examples/rnn.rst
new file mode 100644
index 00000000..02f7ecdb
--- /dev/null
+++ b/docs/source/examples/rnn.rst
@@ -0,0 +1,38 @@
+Recurrent Neural Network (RNN)
+==============================
+
+When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
+element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
+descent can be leveraged to enhance optimization.
+
+.. code-block:: python
+    :emphasize-lines: 5-6, 10, 17, 20
+
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd import backward
+    from torchjd.aggregation import UPGrad
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()
+
+.. note::
+    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
+    ``torch.nn.RNN`` when running on CUDA (see `this issue `_ for more info), so we advise
+    setting ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve
+    performance, you can check whether ``parallel_chunk_size=None`` (maximal parallelization)
+    works in your setup.
diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py
index 944a7528..40e91682 100644
--- a/tests/doc/test_rst.py
+++ b/tests/doc/test_rst.py
@@ -183,3 +183,27 @@ def configure_optimizers(self) -> OptimizerLRScheduler:
     )
 
     trainer.fit(model=model, train_dataloaders=train_loader)
+
+
+def test_rnn():
+    import torch
+    from torch.nn import RNN
+    from torch.optim import SGD
+
+    from torchjd import backward
+    from torchjd.aggregation import UPGrad
+
+    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
+    optimizer = SGD(rnn.parameters(), lr=0.1)
+    aggregator = UPGrad()
+
+    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
+    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.
+
+    for input, target in zip(inputs, targets):
+        output, _ = rnn(input)  # output is of shape [5, 3, 20].
+        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
+
+        optimizer.zero_grad()
+        backward(losses, aggregator, parallel_chunk_size=1)
+        optimizer.step()
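A minimal sketch (not part of this diff) of how one might act on the note in ``rnn.rst``: probe once
whether ``parallel_chunk_size=None`` (maximal parallelization) works on the current device, and fall
back to ``1`` otherwise. The helper name ``pick_chunk_size`` and the assumption that the
``torch.vmap``/``torch.nn.RNN`` incompatibility surfaces as a ``RuntimeError`` are illustrative, not
taken from TorchJD::

    import torch
    from torch.nn import RNN

    from torchjd import backward
    from torchjd.aggregation import UPGrad

    def pick_chunk_size(rnn, aggregator):
        """Return None if maximal parallelization works in this setup, and 1 otherwise."""
        device = next(rnn.parameters()).device
        input = torch.randn(5, 3, 10, device=device)  # 1 batch of 3 sequences of length 5, dim 10.
        target = torch.randn(5, 3, 20, device=device)
        output, _ = rnn(input)
        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.
        try:
            backward(losses, aggregator, parallel_chunk_size=None)  # torch.vmap-based path.
            return None
        except RuntimeError:
            return 1  # Fall back to chunks of size 1, which avoids torch.vmap.
        finally:
            rnn.zero_grad()  # Discard the probe's gradients before real training.

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    chunk_size = pick_chunk_size(rnn, UPGrad())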