* Add lightning to test dependencies
* Add lightning usage example
* Add link to the lightning usage example in the examples index
* Add test for lightning usage example
* Add changelog entry

commit 17c2b94 (1 parent: 95d9017)
Showing 5 changed files with 151 additions and 0 deletions.

PyTorch Lightning Integration
=============================

To use Jacobian descent with TorchJD in a :class:`~lightning.LightningModule`, you need to
disable automatic optimization by setting ``automatic_optimization`` to ``False``, and to
customize the ``training_step`` method so that it calls the appropriate TorchJD function
(:doc:`backward <../docs/autojac/backward>` or :doc:`mtl_backward <../docs/autojac/mtl_backward>`).

The following code example demonstrates a basic multi-task learning setup using a
:class:`~lightning.LightningModule` that calls :doc:`mtl_backward
<../docs/autojac/mtl_backward>` at each training iteration.

.. code-block:: python
    :emphasize-lines: 9-10, 18, 32-38

    import torch
    from lightning import LightningModule, Trainer
    from lightning.pytorch.utilities.types import OptimizerLRScheduler
    from torch.nn import Linear, ReLU, Sequential
    from torch.nn.functional import mse_loss
    from torch.optim import Adam
    from torch.utils.data import DataLoader, TensorDataset

    from torchjd import mtl_backward
    from torchjd.aggregation import UPGrad

    class Model(LightningModule):
        def __init__(self):
            super().__init__()
            self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
            self.task1_head = Linear(3, 1)
            self.task2_head = Linear(3, 1)
            self.automatic_optimization = False  # disable automatic optimization

        def training_step(self, batch, batch_idx) -> None:
            input, target1, target2 = batch

            features = self.feature_extractor(input)
            output1 = self.task1_head(features)
            output2 = self.task2_head(features)

            loss1 = mse_loss(output1, target1)
            loss2 = mse_loss(output2, target2)

            opt = self.optimizers()
            opt.zero_grad()
            mtl_backward(
                losses=[loss1, loss2],
                features=features,
                tasks_params=[self.task1_head.parameters(), self.task2_head.parameters()],
                shared_params=self.feature_extractor.parameters(),
                A=UPGrad(),
            )
            opt.step()

        def configure_optimizers(self) -> OptimizerLRScheduler:
            optimizer = Adam(self.parameters(), lr=1e-3)
            return optimizer

    model = Model()

    inputs = torch.randn(8, 16, 10)  # 8 batches of 16 random input vectors of length 10
    task1_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the first task
    task2_targets = torch.randn(8, 16, 1)  # 8 batches of 16 targets for the second task

    dataset = TensorDataset(inputs, task1_targets, task2_targets)
    train_loader = DataLoader(dataset)

    trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False)
    trainer.fit(model=model, train_dataloaders=train_loader)
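
If your module has no task-specific heads and you simply want to optimize several losses with
respect to the same parameters, the ``training_step`` can call :doc:`backward
<../docs/autojac/backward>` instead of :doc:`mtl_backward <../docs/autojac/mtl_backward>`. The
following is a minimal illustrative sketch, not part of the original example: it reuses the
imports from above, uses a hypothetical ``self.model`` backbone and a made-up second loss, and
assumes ``backward`` takes ``tensors``, ``inputs`` and ``A`` keywords mirroring ``mtl_backward``;
check the :doc:`backward <../docs/autojac/backward>` documentation for the exact signature.

.. code-block:: python

    from torchjd import backward

    def training_step(self, batch, batch_idx) -> None:
        input, target = batch
        output = self.model(input)  # hypothetical single backbone

        # Two objectives on the same parameters, e.g. a data-fitting loss
        # and a (made-up) regularization term.
        loss1 = mse_loss(output, target)
        loss2 = output.abs().mean()

        opt = self.optimizers()
        opt.zero_grad()
        # Aggregate the Jacobian of both losses with UPGrad and populate the
        # .grad fields of the parameters (keyword names assumed, see above).
        backward(tensors=[loss1, loss2], inputs=self.parameters(), A=UPGrad())
        opt.step()
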
.. warning::
    This will not handle automatic scaling in low-precision settings. There is currently no easy
    fix.

.. warning::
    TorchJD is incompatible with compiled models, so you must ensure that your model is not
    compiled.
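
In particular, make sure the eager (uncompiled) module is the one being trained. A minimal
sketch, with ``torch.compile`` shown only as the thing to avoid:

.. code-block:: python

    model = Model()
    # model = torch.compile(model)  # incompatible with TorchJD: do not compile
    trainer.fit(model=model, train_dataloaders=train_loader)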