Uniformize usage of code-block in documentation (#159)
ValerianRey authored Oct 30, 2024
1 parent 66a146c commit 95d9017
Showing 3 changed files with 80 additions and 60 deletions.
54 changes: 35 additions & 19 deletions docs/source/examples/basic_usage.rst
@@ -12,56 +12,72 @@ the parameters are updated using the resulting aggregation.

Import several classes from ``torch`` and ``torchjd``:

>>> import torch
>>> from torch.nn import MSELoss, Sequential, Linear, ReLU
>>> from torch.optim import SGD
>>>
>>> import torchjd
>>> from torchjd.aggregation import UPGrad
.. code-block:: python

    import torch
    from torch.nn import MSELoss, Sequential, Linear, ReLU
    from torch.optim import SGD

    import torchjd
    from torchjd.aggregation import UPGrad

Define the model and the optimizer, as usual:

>>> model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
>>> optimizer = SGD(model.parameters(), lr=0.1)
.. code-block:: python

    model = Sequential(Linear(10, 5), ReLU(), Linear(5, 2))
    optimizer = SGD(model.parameters(), lr=0.1)

Define the aggregator that will be used to combine the rows of the Jacobian matrix:

>>> A = UPGrad()
.. code-block:: python

    A = UPGrad()

In essence, :doc:`UPGrad <../docs/aggregation/upgrad>` projects each gradient onto the dual cone of
the rows of the Jacobian and averages the results. This ensures that locally, no loss will be
negatively affected by the update.
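
As a rough illustration, an aggregator such as ``A`` can also be applied to a small hand-written
Jacobian matrix. This is a minimal sketch for intuition only, assuming the aggregator is callable
on a 2D tensor whose rows are the per-loss gradients; the values of ``J`` are made up:

.. code-block:: python

    # Two conflicting gradients (rows) over three parameters (columns).
    J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])
    direction = A(J)  # single update direction that should not conflict with either row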

Now that everything is defined, we can train the model. Define the input and the associated targets:

>>> input = torch.randn(16, 10) # Batch of 16 random input vectors of length 10
>>> target1 = torch.randn(16) # First batch of 16 targets
>>> target2 = torch.randn(16) # Second batch of 16 targets
.. code-block:: python

    input = torch.randn(16, 10)  # Batch of 16 random input vectors of length 10
    target1 = torch.randn(16)  # First batch of 16 targets
    target2 = torch.randn(16)  # Second batch of 16 targets

Here, we generate fake inputs and labels for the sake of the example.

We can now compute the losses associated with each of the model's two outputs:

>>> loss_fn = MSELoss()
>>> output = model(input)
>>> loss1 = loss_fn(output[:, 0], target1)
>>> loss2 = loss_fn(output[:, 1], target2)
.. code-block:: python

    loss_fn = MSELoss()
    output = model(input)
    loss1 = loss_fn(output[:, 0], target1)
    loss2 = loss_fn(output[:, 1], target2)

The last steps are similar to gradient descent-based optimization, but using the two losses.

Reset the ``.grad`` field of each model parameter:

>>> optimizer.zero_grad()
.. code-block:: python

    optimizer.zero_grad()

Perform the Jacobian descent backward pass:

>>> torchjd.backward([loss1, loss2], model.parameters(), A)
.. code-block:: python

    torchjd.backward([loss1, loss2], model.parameters(), A)

This will populate the ``.grad`` field of each model parameter with the corresponding aggregation
of the Jacobian matrix.
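
To check the result, one can inspect these fields; this is a small sketch added for illustration
and not part of the original example:

.. code-block:: python

    # After the backward pass, each parameter's .grad has the parameter's own shape.
    for name, param in model.named_parameters():
        print(name, param.grad.shape)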

Update each parameter based on its ``.grad`` field, using the ``optimizer``:

>>> optimizer.step()
.. code-block:: python

    optimizer.step()

The model's parameters have been updated!
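
In practice, these steps would be repeated over many iterations. A minimal sketch reusing the
objects defined above (not part of the original example) could look like this:

.. code-block:: python

    for _ in range(100):  # hypothetical number of iterations
        output = model(input)
        loss1 = loss_fn(output[:, 0], target1)
        loss2 = loss_fn(output[:, 1], target2)

        optimizer.zero_grad()
        torchjd.backward([loss1, loss2], model.parameters(), A)
        optimizer.step()
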
2 changes: 1 addition & 1 deletion docs/source/examples/iwrm.rst
@@ -66,7 +66,7 @@ each Jacobian matrix consists of one gradient per loss. In this example, we use
IWRM with SSJD
^^^^^^^^^^^^^^
.. code-block:: python
    :emphasize-lines: 10, 11, 21, 25, 29, 31
    :emphasize-lines: 10-11, 21, 25, 29, 31

    import torch
    from torch.nn import (
84 changes: 44 additions & 40 deletions docs/source/examples/mtl.rst
@@ -17,46 +17,50 @@ example shows how to use TorchJD to train a very simple multi-task model with tw
For the sake of the example, we generate a fake dataset consisting of 8 batches of 16 random input
vectors of dimension 10, and their corresponding scalar labels for both tasks.

>>> import torch
>>> from torch.nn import Linear, MSELoss, ReLU, Sequential
>>> from torch.optim import SGD
>>>
>>> from torchjd import mtl_backward
>>> from torchjd.aggregation import UPGrad
>>>
>>> shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
>>> task1_module = Linear(3, 1)
>>> task2_module = Linear(3, 1)
>>> params = [
>>> *shared_module.parameters(),
>>> *task1_module.parameters(),
>>> *task2_module.parameters(),
>>> ]
>>>
>>> loss_fn = MSELoss()
>>> optimizer = SGD(params, lr=0.1)
>>> A = UPGrad()
>>>
>>> inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
>>> task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task
>>> task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task
>>>
>>> for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
>>> features = shared_module(input)
>>> output1 = task1_module(features)
>>> output2 = task2_module(features)
>>> loss1 = loss_fn(output1, target1)
>>> loss2 = loss_fn(output2, target2)
>>>
>>> optimizer.zero_grad()
>>> mtl_backward(
... losses=[loss1, loss2],
... features=features,
... tasks_params=[task1_module.parameters(), task2_module.parameters()],
... shared_params=shared_module.parameters(),
... A=A,
... )
>>> optimizer.step()

.. code-block:: python
    :emphasize-lines: 5-6, 19, 33-39

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_module = Linear(3, 1)
    task2_module = Linear(3, 1)
    params = [
        *shared_module.parameters(),
        *task1_module.parameters(),
        *task2_module.parameters(),
    ]

    loss_fn = MSELoss()
    optimizer = SGD(params, lr=0.1)
    A = UPGrad()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    for input, target1, target2 in zip(inputs, task1_targets, task2_targets):
        features = shared_module(input)
        output1 = task1_module(features)
        output2 = task2_module(features)
        loss1 = loss_fn(output1, target1)
        loss2 = loss_fn(output2, target2)

        optimizer.zero_grad()
        mtl_backward(
            losses=[loss1, loss2],
            features=features,
            tasks_params=[task1_module.parameters(), task2_module.parameters()],
            shared_params=shared_module.parameters(),
            A=A,
        )
        optimizer.step()

.. note::
In this example, the Jacobian is only with respect to the shared parameters. The task-specific
