Avoid using vmap when parallel_chunk_size=1 #221

Closed · wants to merge 8 commits
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,14 @@ changes that do not affect the user.

## [Unreleased]

### Changed

- Changed how Jacobians are computed when `backward` or `mtl_backward` is called with
  `parallel_chunk_size=1`: in that case the Jacobian rows are now computed sequentially instead of
  through `torch.vmap`. Whenever `vmap` does not support something (compiled functions, RNN on
  CUDA, etc.), users can now avoid it by calling `backward` or `mtl_backward` with
  `parallel_chunk_size=1` and `retain_graph=True`.
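
  For context, a minimal usage sketch of this escape hatch (assuming the public `torchjd.backward`
  and `torchjd.aggregation.UPGrad` entry points; the model and losses are made up for
  illustration):

  ```python
  import torch
  from torchjd import backward
  from torchjd.aggregation import UPGrad

  model = torch.nn.Linear(10, 2)
  x = torch.randn(8, 10)
  losses = model(x).sum(dim=0)  # Two scalar losses sharing the model's parameters.

  # parallel_chunk_size=1 computes the Jacobian rows sequentially instead of through torch.vmap;
  # retain_graph=True is then required because the autograd graph is differentiated once per row.
  backward(losses, UPGrad(), parallel_chunk_size=1, retain_graph=True)
  ```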

## [0.3.1] - 2024-12-21

### Changed
21 changes: 17 additions & 4 deletions src/torchjd/autojac/_transform/jac.py
@@ -68,10 +68,23 @@ def get_vjp(grad_outputs: Sequence[Tensor]) -> Tensor:
grads = _materialize(optional_grads, inputs=inputs)
return torch.concatenate([grad.reshape([-1]) for grad in grads])

# Because of a limitation of vmap, this breaks when some tensors have `retains_grad=True`.
# See https://pytorch.org/functorch/stable/ux_limitations.html for more information.
# This also breaks when some tensors have been produced by compiled functions.
grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=self.chunk_size)(jac_outputs)
if self.chunk_size == 1:
    # In this special case, vmap is not needed, and avoiding it sidesteps its limitations.
    # The result should be equivalent to the vmap call, but this path keeps working in cases
    # where vmap breaks (compiled functions, RNN on CUDA, etc.).
    rows = []
    for i in range(jac_outputs[0].shape[0]):
        grad_outputs = [jac_output[i] for jac_output in jac_outputs]
        gradient_vector = get_vjp(grad_outputs)
        rows.append(gradient_vector)
    grouped_jacobian_matrix = torch.vstack(rows)
else:
    # Because of a limitation of vmap, this breaks when some tensors have
    # `retains_grad=True`. See https://pytorch.org/functorch/stable/ux_limitations.html for
    # more information. It also breaks when some tensors have been produced by compiled
    # functions, and in some other cases (RNN on CUDA, etc.).
    grouped_jacobian_matrix = torch.vmap(get_vjp, chunk_size=self.chunk_size)(jac_outputs)

lengths = [input.numel() for input in inputs]
jacobian_matrices = _extract_sub_matrices(grouped_jacobian_matrix, lengths)
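To illustrate the change in isolation, here is a standalone sketch (not torchjd code; it mirrors the `get_vjp` structure above with a made-up toy model) showing that computing the Jacobian row by row with plain autograd calls matches the vmap-based computation while never entering `torch.vmap`, which is why the row-by-row path works even where vmap does not:

```python
import torch

# Toy setup: two scalar outputs depending on one parameter tensor.
params = [torch.randn(3, requires_grad=True)]
x = torch.randn(3)
outputs = torch.stack([params[0] @ x, (params[0] ** 2) @ x])
jac_outputs = torch.eye(2)  # One grad_output vector per Jacobian row.


def get_vjp(grad_output: torch.Tensor) -> torch.Tensor:
    # One vector-Jacobian product, flattened over all parameters.
    grads = torch.autograd.grad(outputs, params, grad_output, retain_graph=True)
    return torch.cat([g.reshape(-1) for g in grads])


# Row-by-row computation: what chunk_size=1 now does, without vmap.
jacobian_loop = torch.vstack([get_vjp(jac_outputs[i]) for i in range(jac_outputs.shape[0])])

# Batched computation through vmap: what the other chunk sizes still use.
jacobian_vmap = torch.vmap(get_vjp)(jac_outputs)

assert torch.allclose(jacobian_loop, jacobian_vmap)
```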
27 changes: 16 additions & 11 deletions tests/unit/autojac/_transform/test_jac.py
@@ -1,13 +1,14 @@
import torch
from pytest import raises
from pytest import mark, raises
from unit.conftest import DEVICE

from torchjd.autojac._transform import Jac, Jacobians

from ._dict_assertions import assert_tensor_dicts_are_close


def test_single_input():
@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)])
def test_single_input(chunk_size: int | None, retain_graph: bool):
"""
Tests that the Jac transform works correctly for an example of multiple differentiation. Here,
the function considered is: `y = [a1 * x, a2 * x]`. We want to compute the jacobians of `y` with
@@ -20,7 +21,7 @@ def test_single_input():
y = torch.stack([a1 * x, a2 * x])
input = Jacobians({y: torch.eye(2, device=DEVICE)})

jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=None)
jac = Jac(outputs=[y], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=retain_graph)

jacobians = jac(input)
expected_jacobians = {
@@ -31,7 +32,8 @@ def test_single_input():
assert_tensor_dicts_are_close(jacobians, expected_jacobians)


def test_empty_inputs_1():
@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)])
def test_empty_inputs_1(chunk_size: int | None, retain_graph: bool):
"""
Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
"""
@@ -41,15 +43,16 @@ def test_empty_inputs_1():
y = torch.stack([y1, y2])
input = Jacobians({y: torch.eye(2, device=DEVICE)})

jac = Jac(outputs=[y], inputs=[], chunk_size=None)
jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size, retain_graph=retain_graph)

jacobians = jac(input)
expected_jacobians = {}

assert_tensor_dicts_are_close(jacobians, expected_jacobians)


def test_empty_inputs_2():
@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)])
def test_empty_inputs_2(chunk_size: int | None, retain_graph: bool):
"""
Tests that the Jac transform works correctly when the `inputs` parameter is an empty `Iterable`.
"""
@@ -62,7 +65,7 @@ def test_empty_inputs_2():
y = torch.stack([y1, y2])
input = Jacobians({y: torch.eye(2, device=DEVICE)})

jac = Jac(outputs=[y], inputs=[], chunk_size=None)
jac = Jac(outputs=[y], inputs=[], chunk_size=chunk_size, retain_graph=retain_graph)

jacobians = jac(input)
expected_jacobians = {}
@@ -122,7 +125,8 @@ def test_two_levels():
assert_tensor_dicts_are_close(jacobians, expected_jacobians)


def test_multiple_outputs_1():
@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)])
def test_multiple_outputs_1(chunk_size: int | None, retain_graph: bool):
"""
Tests that the Jac transform works correctly when the `outputs` contains 3 vectors.
The input (jac_outputs) is not the same for all outputs, so that this test also checks that the
@@ -143,7 +147,7 @@ def test_multiple_outputs_1():
jac_output3 = torch.cat([zeros_2x2, zeros_2x2, identity_2x2])
input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3})

jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None)
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=retain_graph)

jacobians = jac(input)
zero_scalar = torch.tensor(0.0, device=DEVICE)
@@ -155,7 +159,8 @@ def test_multiple_outputs_1():
assert_tensor_dicts_are_close(jacobians, expected_jacobians)


def test_multiple_outputs_2():
@mark.parametrize(["chunk_size", "retain_graph"], [(1, True), (3, True), (None, False)])
def test_multiple_outputs_2(chunk_size: int | None, retain_graph: bool):
"""
Same as test_multiple_outputs_1 but with different jac_outputs, so the returned jacobians are of
different shapes.
@@ -175,7 +180,7 @@ def test_multiple_outputs_2():
jac_output3 = torch.stack([zeros_2, zeros_2, ones_2])
input = Jacobians({y1: jac_output1, y2: jac_output2, y3: jac_output3})

jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=None)
jac = Jac(outputs=[y1, y2, y3], inputs=[a1, a2], chunk_size=chunk_size, retain_graph=retain_graph)

jacobians = jac(input)
expected_jacobians = {
68 changes: 64 additions & 4 deletions tests/unit/autojac/test_backward.py
@@ -2,6 +2,7 @@

import torch
from pytest import mark, raises
from torch import nn
from torch.testing import assert_close
from unit._utils import ExceptionContext
from unit.conftest import DEVICE
@@ -26,11 +27,15 @@ def test_various_aggregators(aggregator: Aggregator):
assert (a.grad is not None) and (a.shape == a.grad.shape)


@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)])
@mark.parametrize("aggregator", [Mean(), UPGrad()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("manually_specify_inputs", [True, False])
@mark.parametrize("chunk_size", [1, 3, None])
def test_value_is_correct(
aggregator: Aggregator, shape: tuple[int, int], manually_specify_inputs: bool
aggregator: Aggregator,
shape: tuple[int, int],
manually_specify_inputs: bool,
chunk_size: int | None,
):
"""
Tests that the .grad value filled by backward is correct in a simple example of matrix-vector
@@ -46,7 +51,13 @@ def test_value_is_correct(
else:
inputs = None

backward([output], aggregator, inputs=inputs)
backward(
[output],
aggregator,
inputs=inputs,
retain_graph=True,
parallel_chunk_size=chunk_size,
)

assert_close(input.grad, aggregator(J))

@@ -203,3 +214,52 @@ def test_non_input_retaining_grad_fails():
with raises(RuntimeError):
# Using such a BatchedTensor should result in an error
_ = -b.grad


@mark.parametrize("chunk_size", [1, 3, None])
def test_tensor_used_multiple_times(chunk_size: int | None):
"""
Tests that backward works correctly when one of the inputs is used multiple times. In this
setup, the autograd graph is still acyclic, but the graph of tensors used becomes cyclic.
"""

a = torch.tensor(3.0, requires_grad=True, device=DEVICE)
b = 2.0 * a
c = a * b
d = a * c
e = a * d
aggregator = UPGrad()

backward([d, e], aggregator=aggregator, parallel_chunk_size=chunk_size, retain_graph=True)

expected_jacobian = torch.tensor(
[
[2.0 * 3.0 * a**2],
[2.0 * 4.0 * a**3],
],
device=DEVICE,
)

assert_close(a.grad, aggregator(expected_jacobian).squeeze())


def test_rnn():
"""
Tests that backward works for a very simple RNN, adapted from
[PyTorch's documentation](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html).
"""

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2).to(device=DEVICE)
input = torch.randn(5, 3, 10, device=DEVICE) # Batch of 3 sequences of length 5 and of dim 10.
h0 = torch.randn(2, 3, 20, device=DEVICE) # Batch of 3 hidden states of 2 layers of dim 20.
output, _ = rnn(input, h0) # Output is of shape [5, 3, 20].
target = torch.randn(5, 3, 20, device=DEVICE) # Batch of 3 sequences of len 5 and of dim 20.
losses = ((output - target) ** 2).sum(dim=[1, 2])  # One loss per time step (5 in total).
aggregator = UPGrad()

# Setting parallel_chunk_size to 1 avoids vmap, which is necessary here because the CUDA
# implementation of RNN is not supported by vmap.
backward(tensors=losses, aggregator=aggregator, parallel_chunk_size=1, retain_graph=True)

for param in rnn.parameters():
assert param.grad is not None and param.grad.shape == param.shape
56 changes: 54 additions & 2 deletions tests/unit/autojac/test_mtl_backward.py
@@ -1,7 +1,10 @@
from contextlib import nullcontext as does_not_raise
from itertools import chain

import torch
from pytest import mark, raises
from torch import nn
from torch.nn import BCELoss, MSELoss
from torch.testing import assert_close
from unit._utils import ExceptionContext
from unit.conftest import DEVICE
@@ -29,15 +32,17 @@ def test_various_aggregators(aggregator: Aggregator):
assert (p.grad is not None) and (p.shape == p.grad.shape)


@mark.parametrize("aggregator", [Mean(), UPGrad(), MGDA()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (60, 55), (120, 143)])
@mark.parametrize("aggregator", [Mean(), UPGrad()])
@mark.parametrize("shape", [(2, 3), (2, 6), (5, 8), (20, 55)])
@mark.parametrize("manually_specify_shared_params", [True, False])
@mark.parametrize("manually_specify_tasks_params", [True, False])
@mark.parametrize("chunk_size", [1, 3, None])
def test_value_is_correct(
aggregator: Aggregator,
shape: tuple[int, int],
manually_specify_shared_params: bool,
manually_specify_tasks_params: bool,
chunk_size: int | None,
):
"""
Tests that the .grad value filled by mtl_backward is correct in a simple example of
@@ -74,6 +79,8 @@ def test_value_is_correct(
aggregator=aggregator,
tasks_params=tasks_params,
shared_params=shared_params,
retain_graph=True,
parallel_chunk_size=chunk_size,
)

assert_close(p1.grad, f)
@@ -592,3 +599,48 @@ def test_default_shared_params_overlapping_with_default_tasks_params_fails():
aggregator=UPGrad(),
retain_graph=True,
)


def test_rnn():
"""
Tests that mtl_backward works for a simple multitask model whose feature extractor is an RNN
adapted from
[PyTorch's documentation](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html).

Here, we have a binary classification task and a 4-regressions task using the last hidden state
of the RNN as shared input features.
"""

rnn = nn.RNN(input_size=10, hidden_size=20, num_layers=2).to(device=DEVICE)
cls_head = nn.Linear(40, 1).to(device=DEVICE)
reg_head = nn.Linear(40, 4).to(device=DEVICE)

input = torch.randn(5, 3, 10, device=DEVICE) # Batch of 3 sequences of length 5 and of dim 10.
h0 = torch.randn(2, 3, 20, device=DEVICE) # Batch of 3 hidden states of 2 layers of dim 20.
_, hn = rnn(input, h0) # hn is of shape [2, 3, 20].
features = hn.permute(1, 0, 2).reshape(3, -1)
cls_output = torch.sigmoid(cls_head(features)).squeeze()
reg_output = reg_head(features)

cls_loss_fn = BCELoss()
reg_loss_fn = MSELoss()

cls_target = torch.tensor([1.0, 0.0, 1.0], device=DEVICE)
reg_target = torch.randn(3, 4, device=DEVICE)

cls_loss = cls_loss_fn(cls_output, cls_target)
reg_loss = reg_loss_fn(reg_output, reg_target)
losses = [cls_loss, reg_loss]

# Setting parallel_chunk_size to 1 avoids vmap, which is necessary here because the CUDA
# implementation of RNN is not supported by vmap.
mtl_backward(
losses=losses,
features=features,
aggregator=UPGrad(),
parallel_chunk_size=1,
retain_graph=True,
)

for param in chain(rnn.parameters(), cls_head.parameters(), reg_head.parameters()):
assert param.grad is not None and param.grad.shape == param.shape