Commit
Add lightning usage example (#160)
* Add lightning to test dependencies
* Add lightning usage example
* Add link to the lightning usage example in the examples index
* Add test for lightning usage example
* Add changelog entry
ValerianRey authored Oct 30, 2024
1 parent 95d9017 commit 17c2b94
Showing 5 changed files with 151 additions and 0 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). This changelog does not include internal
changes that do not affect the user.

## [Unreleased]

### Added

- PyTorch Lightning integration example.

## [0.2.1] - 2024-09-17

### Changed
4 changes: 4 additions & 0 deletions docs/source/examples/index.rst
@@ -13,10 +13,14 @@ This section contains some usage examples for TorchJD.
- :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian
descent is used to optimize the vector of per-task losses of a multi-task model, using the
dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`.
- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
  TorchJD with PyTorch Lightning by providing an example implementation of a multi-task
  ``LightningModule`` optimized by Jacobian descent.

.. toctree::
    :hidden:

    basic_usage.rst
    iwrm.rst
    mtl.rst
    lightning_integration.rst
78 changes: 78 additions & 0 deletions docs/source/examples/lightning_integration.rst
@@ -0,0 +1,78 @@
PyTorch Lightning Integration
=============================

To use Jacobian descent with TorchJD in a :class:`~lightning.LightningModule`, you need to turn off
automatic optimization by setting ``automatic_optimization`` to ``False``, and to customize the
``training_step`` method so that it calls the appropriate TorchJD function (:doc:`backward
<../docs/autojac/backward>` or :doc:`mtl_backward <../docs/autojac/mtl_backward>`).

The following code example demonstrates a basic multi-task learning setup using a
:class:`~lightning.LightningModule` that will call :doc:`mtl_backward
<../docs/autojac/mtl_backward>` at each training iteration.

.. code-block:: python
    :emphasize-lines: 9-10, 18, 32-38

    import torch
    from lightning import LightningModule, Trainer
    from lightning.pytorch.utilities.types import OptimizerLRScheduler
    from torch.nn import Linear, ReLU, Sequential
    from torch.nn.functional import mse_loss
    from torch.optim import Adam
    from torch.utils.data import DataLoader, TensorDataset

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    class Model(LightningModule):
        def __init__(self):
            super().__init__()
            self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
            self.task1_head = Linear(3, 1)
            self.task2_head = Linear(3, 1)
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx) -> None:
            input, target1, target2 = batch

            features = self.feature_extractor(input)
            output1 = self.task1_head(features)
            output2 = self.task2_head(features)

            loss1 = mse_loss(output1, target1)
            loss2 = mse_loss(output2, target2)

            opt = self.optimizers()
            opt.zero_grad()
            mtl_backward(
                losses=[loss1, loss2],
                features=features,
                tasks_params=[self.task1_head.parameters(), self.task2_head.parameters()],
                shared_params=self.feature_extractor.parameters(),
                A=UPGrad(),
            )
            opt.step()

        def configure_optimizers(self) -> OptimizerLRScheduler:
            optimizer = Adam(self.parameters(), lr=1e-3)
            return optimizer

    model = Model()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    dataset = TensorDataset(inputs, task1_targets, task2_targets)
    train_loader = DataLoader(dataset)
    trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False)

    trainer.fit(model=model, train_dataloaders=train_loader)
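
As a quick sanity check (an illustrative follow-up, not part of the original example), the trained
module can be used like any other ``torch.nn.Module``. Since ``Model`` does not define ``forward``,
the feature extractor and the task heads are called directly:

.. code-block:: python

    test_input = torch.randn(4, 10)  # a small batch of 4 input vectors of length 10
    with torch.no_grad():
        features = model.feature_extractor(test_input)
        prediction1 = model.task1_head(features)  # shape [4, 1]
        prediction2 = model.task2_head(features)  # shape [4, 1]
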
.. warning::
    Because this example calls ``mtl_backward`` directly rather than going through Lightning's
    ``self.manual_backward``, automatic gradient scaling in low-precision settings is not handled.
    There is currently no easy fix.

.. warning::
    TorchJD is incompatible with compiled models, so you must ensure that your model is not
    compiled.
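
For instance (an illustrative sketch, not part of the original example), pass the plain eager-mode
module to the trainer rather than a compiled one:

.. code-block:: python

    model = Model()  # OK: plain eager-mode LightningModule
    # model = torch.compile(Model())  # not supported: TorchJD cannot backpropagate through compiled models
    trainer.fit(model=model, train_dataloaders=train_loader)
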
1 change: 1 addition & 0 deletions pyproject.toml
@@ -52,6 +52,7 @@ doc = [

test = [
"pytest>=7.3", # Before version 7.3, not all tests are run
"lightning>=2.0.9", # No OptimizerLRScheduler public type before 2.0.9
]

plot = [
62 changes: 62 additions & 0 deletions tests/doc/test_rst.py
@@ -116,3 +116,65 @@ def test_mtl():
        A=A,
    )
    optimizer.step()


def test_lightning_integration():
    import warnings

    warnings.filterwarnings("ignore")

    import torch
    from lightning import LightningModule, Trainer
    from lightning.pytorch.utilities.types import OptimizerLRScheduler
    from torch.nn import Linear, ReLU, Sequential
    from torch.nn.functional import mse_loss
    from torch.optim import Adam
    from torch.utils.data import DataLoader, TensorDataset

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    class Model(LightningModule):
        def __init__(self):
            super().__init__()
            self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
            self.task1_head = Linear(3, 1)
            self.task2_head = Linear(3, 1)
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx) -> None:
            input, target1, target2 = batch

            features = self.feature_extractor(input)
            output1 = self.task1_head(features)
            output2 = self.task2_head(features)

            loss1 = mse_loss(output1, target1)
            loss2 = mse_loss(output2, target2)

            opt = self.optimizers()
            opt.zero_grad()
            mtl_backward(
                losses=[loss1, loss2],
                features=features,
                tasks_params=[self.task1_head.parameters(), self.task2_head.parameters()],
                shared_params=self.feature_extractor.parameters(),
                A=UPGrad(),
            )
            opt.step()

        def configure_optimizers(self) -> OptimizerLRScheduler:
            optimizer = Adam(self.parameters(), lr=1e-3)
            return optimizer

    model = Model()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    dataset = TensorDataset(inputs, task1_targets, task2_targets)
    train_loader = DataLoader(dataset)
    trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False)

    trainer.fit(model=model, train_dataloaders=train_loader)
