* Add examples/rnn.rst
* Add link to the RNN example in examples/index.rst
* Add doc test in test_rst.py
Commit 4d84d3e (1 parent: 654dd88)
Showing 3 changed files with 65 additions and 0 deletions.
examples/rnn.rst
@@ -0,0 +1,38 @@
Recurrent Neural Network (RNN)
==============================

When training recurrent neural networks for sequence modelling, we can easily obtain one loss per
element of the output sequences. If the gradients of these losses are likely to conflict, Jacobian
descent can be leveraged to enhance optimization.

.. code-block:: python
    :emphasize-lines: 5-6, 10, 17, 20

    import torch
    from torch.nn import RNN
    from torch.optim import SGD

    from torchjd import backward
    from torchjd.aggregation import UPGrad

    rnn = RNN(input_size=10, hidden_size=20, num_layers=2)
    optimizer = SGD(rnn.parameters(), lr=0.1)
    aggregator = UPGrad()

    inputs = torch.randn(8, 5, 3, 10)  # 8 batches of 3 sequences of length 5 and of dim 10.
    targets = torch.randn(8, 5, 3, 20)  # 8 batches of 3 sequences of length 5 and of dim 20.

    for input, target in zip(inputs, targets):
        output, _ = rnn(input)  # output is of shape [5, 3, 20].
        losses = ((output - target) ** 2).mean(dim=[1, 2])  # 1 loss per sequence element.

        optimizer.zero_grad()
        backward(losses, aggregator, parallel_chunk_size=1)
        optimizer.step()

.. note::
    At the time of writing, there seems to be an incompatibility between ``torch.vmap`` and
    ``torch.nn.RNN`` when running on CUDA (see `this issue
    <https://github.com/TorchJD/torchjd/issues/220>`_ for more info), so we advise setting
    ``parallel_chunk_size`` to ``1`` to avoid using ``torch.vmap``. To improve performance, you can
    check whether ``parallel_chunk_size=None`` (maximal parallelization) works on your side.
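
One way to check this is to probe maximal parallelization once before training. The following is
a minimal sketch, not part of torchjd: it assumes the ``rnn``, ``aggregator``, ``inputs`` and
``targets`` defined above, and that a ``torch.vmap`` failure surfaces as a ``RuntimeError``.

.. code-block:: python

    try:
        # Probe one batch with maximal parallelization (parallel_chunk_size=None).
        probe_output, _ = rnn(inputs[0])
        probe_losses = ((probe_output - targets[0]) ** 2).mean(dim=[1, 2])
        backward(probe_losses, aggregator, parallel_chunk_size=None)
        chunk_size = None  # torch.vmap works here, keep maximal parallelization.
    except RuntimeError:
        chunk_size = 1  # Fall back to sequential differentiation (no torch.vmap).

    rnn.zero_grad()  # Discard the gradients accumulated by the probe.

The training loop can then pass ``parallel_chunk_size=chunk_size`` to ``backward``.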