Refactor dual cone projections #237

Merged (36 commits) on Feb 4, 2025
aa9897e
Factorize projection onto the dual cone.
PierreQuinton Jan 21, 2025
f534006
Move regularization out of projection onto the dual cone.
PierreQuinton Jan 21, 2025
864d7a4
Improve code coverage
PierreQuinton Jan 21, 2025
b0739fe
Add link to paper in weights_of_projection
PierreQuinton Feb 1, 2025
cd43f65
Rename functions in _dual_cone_utils
PierreQuinton Feb 1, 2025
402024c
Rename lagrangian to lagrangian_multipliers to match the maths
PierreQuinton Feb 1, 2025
a7ae47f
Separate _get_projection_weights into a tensorization of _get_lagrang…
PierreQuinton Feb 1, 2025
895b123
call the cast to numpy (and thus CPU) only once on the weights.
PierreQuinton Feb 1, 2025
25c5666
Improve the tensorization of _get_lagrange_multipliers_array
PierreQuinton Feb 1, 2025
8c52aa5
Improve readability of tensorization by using a for comprehension.
PierreQuinton Feb 1, 2025
3591577
Exchange _compute_normalized_regularized_gramian and _regularize_gramian
PierreQuinton Feb 1, 2025
c631f30
Add comment explaining the goal of regularization
PierreQuinton Feb 1, 2025
4aa0cbd
Rename:
PierreQuinton Feb 1, 2025
193ff9c
Exchange _normalize and _normalize_and_regularize
PierreQuinton Feb 1, 2025
a62c708
rename epsilons in _normalize and _regularize
PierreQuinton Feb 1, 2025
fb07035
Set explicitly the dtype and device of regularization matrix
PierreQuinton Feb 2, 2025
756e8c5
Rename functions & variables, add _to_array
ValerianRey Feb 2, 2025
601bc6e
Fix bug of transposed output
ValerianRey Feb 2, 2025
b3615cd
Delete outdated tests
ValerianRey Feb 2, 2025
c9afb41
Rename functions _normalize etc
ValerianRey Feb 2, 2025
e141600
Improve docstring and naming of test of kkt conditions.
PierreQuinton Feb 2, 2025
9c7eaac
Improve docstring of _get_projection_weights
PierreQuinton Feb 2, 2025
f04581e
Add test of tensorization of _get_projection_weights
PierreQuinton Feb 2, 2025
a5f562a
Clarify the QP problem that is solved, it now matches exactly that of…
PierreQuinton Feb 3, 2025
627adf0
Improve docstring of _get_projection_weight_vector
PierreQuinton Feb 3, 2025
bb411c9
Add changelog entry
PierreQuinton Feb 4, 2025
782aad6
Add no_grad context to the computation of UPGrad and DualProj as we c…
PierreQuinton Feb 4, 2025
ffca7d0
Improve dual cone utility functions
ValerianRey Feb 4, 2025
60a3851
Improve variable names:
ValerianRey Feb 4, 2025
c7a2fe8
Improve docstring of test_solution_weights
ValerianRey Feb 4, 2025
633b286
Remove torch.no_grad in DualProj and UPGrad
ValerianRey Feb 4, 2025
4f3e48a
Add docstring to test_tensorization_shape
ValerianRey Feb 4, 2025
0c224fc
Improve docstring formatting
ValerianRey Feb 4, 2025
c759201
Simplify UPGrad
ValerianRey Feb 4, 2025
c604c80
Add docstring to _regularize
ValerianRey Feb 4, 2025
20aa85c
Improve formatting
ValerianRey Feb 4, 2025
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ changes that do not affect the user.

## [Unreleased]

### Changed

- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
  project onto the dual cone. This may minimally affect the output of these aggregators.

## [0.5.0] - 2025-02-01

### Added
55 changes: 55 additions & 0 deletions src/torchjd/aggregation/_dual_cone_utils.py
@@ -0,0 +1,55 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor


def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
    """
    Computes the tensor of weights corresponding to the projections of the vectors in `U` onto the
    dual cone of the rows of a matrix whose Gramian is provided.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    :return: A tensor of projection weights with the same shape as `U`.
    """

    G_ = _to_array(G)
    U_ = _to_array(U)

    W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_)

    return torch.as_tensor(W, device=G.device, dtype=G.dtype)


def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
    r"""
    Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
    given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
    `\pi_J(J^T u) = J^T w`.

    By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:

        minimize    v^T G v
        subject to  u \preceq v

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

    :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
        project.
    :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    :return: The vector of weights `w` of shape `[m]`.
    """

    m = G.shape[0]
    # Encode `u \preceq v` as `-I v <= -u` to match the `G x <= h` form expected by solve_qp.
    w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)
    return w


def _to_array(tensor: Tensor) -> np.ndarray:
    """Transforms a tensor into a numpy array with float64 dtype."""

    return tensor.cpu().detach().numpy().astype(np.float64)
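
A minimal sketch of how `_project_weights` can be exercised (a hedged illustration, not code from the PR; the data and tolerances are arbitrary). Because `np.apply_along_axis` maps `_project_weight_vector` over the last axis, a whole batch of weight vectors is projected in one call, and each solution satisfies the feasibility conditions of the projection:

import torch

from torchjd.aggregation._dual_cone_utils import _project_weights

J = torch.randn(4, 10)
G = J @ J.T  # Gramian, positive semi-definite by construction

# A batch of three weight vectors; _project_weights maps over the last axis.
U = torch.rand(3, 4)
W = _project_weights(U, G, "quadprog")

# Feasibility of each projection: w >= u and G w >= 0 (up to tolerance).
assert torch.all(W >= U - 1e-5)
assert torch.all(W @ G >= -1e-5)
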
25 changes: 23 additions & 2 deletions src/torchjd/aggregation/_gramian_utils.py
@@ -7,10 +7,16 @@ def _compute_gramian(matrix: Tensor) -> Tensor:
"""
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
"""

return matrix @ matrix.T


def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
return _regularize(normalized_gramian, reg_eps)


def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
r"""
Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
:math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
Expand All @@ -35,11 +41,26 @@ def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
"issue on https://github.com/TorchJD/torchjd/issues and paste this error message in it."
) from error
max_singular_value = torch.max(singular_values)
if max_singular_value < norm_eps:
if max_singular_value < eps:
scaled_singular_values = torch.zeros_like(singular_values)
else:
scaled_singular_values = singular_values / max_singular_value
normalized_gramian = (
left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
)
return normalized_gramian


def _regularize(gramian: Tensor, eps: float) -> Tensor:
"""
Adds a regularization term to the gramian to enforce positive definiteness.

Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a
regularization term which is a small proportion of the identity matrix ensures that the gramian
is positive definite.
"""

regularization_matrix = eps * torch.eye(
gramian.shape[0], dtype=gramian.dtype, device=gramian.device
)
return gramian + regularization_matrix
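
Adding `eps * I` shifts every eigenvalue of the Gramian up by exactly `eps` (if `G v = λ v`, then `(G + eps * I) v = (λ + eps) v`), which is why eigenvalues that numerical error pushed slightly below zero become positive. A hedged illustration (the matrix is arbitrary and `eps` mirrors the `reg_eps=0.0001` used elsewhere in the test suite):

import torch

eps = 1e-4
J = torch.randn(5, 3)
G = J @ J.T  # rank <= 3, so this 5x5 Gramian has zero eigenvalues

regularized = G + eps * torch.eye(5, dtype=G.dtype, device=G.device)

# eigvalsh(G) may contain tiny negative values; after regularization, every
# eigenvalue is shifted up by eps, making the matrix positive definite.
assert torch.all(torch.linalg.eigvalsh(regularized) > 0)
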
33 changes: 6 additions & 27 deletions src/torchjd/aggregation/dualproj.py
@@ -1,11 +1,9 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._dual_cone_utils import _project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import (
    _check_pref_vector,
    _pref_vector_to_str_suffix,
@@ -106,26 +104,7 @@ def __init__(
        self.solver = solver

    def forward(self, matrix: Tensor) -> Tensor:
        weights = self.weighting(matrix)
        weights_array = weights.cpu().detach().numpy().astype(np.float64)

        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
        dimension = gramian.shape[0]

        # Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),
        # which makes quadprog misbehave. Adding a regularization term which is a small proportion
        # of the identity matrix ensures that the gramian is positive definite.
        regularization_array = self.reg_eps * np.eye(dimension)
        regularized_gramian_array = gramian_array + regularization_array

        P = regularized_gramian_array
        q = regularized_gramian_array @ weights_array
        G = -np.eye(dimension)
        h = np.zeros(dimension)

        projection_weights_array = solve_qp(P, q, G, h, solver=self.solver)
        projection_weights = torch.from_numpy(projection_weights_array).to(
            device=matrix.device, dtype=matrix.dtype
        )
        return projection_weights + weights
        u = self.weighting(matrix)
        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
        w = _project_weights(u, G, self.solver)
        return w
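
The refactor changes the quadratic program that `forward` solves. The old code solved for the offset `x = w - u` (minimize `(1/2) x^T G x + (G u)^T x` subject to `0 \preceq x`) and then added `u` back, while the new code solves for `w` directly (minimize `w^T G w` subject to `u \preceq w`). Substituting `x = w - u` shows the two objectives differ only by the constant `-(1/2) u^T G u`, so the minimizers coincide up to solver tolerance, which is why the changelog entry only warns of a minimal effect on the output. A hedged numerical check of this equivalence (illustrative data, not from the PR):

import numpy as np
from qpsolvers import solve_qp

rng = np.random.default_rng(0)
J = rng.standard_normal((4, 10))
G = J @ J.T + 1e-4 * np.eye(4)  # regularized Gramian, positive definite
u = rng.random(4)

# Old formulation: solve for the offset x = w - u, constrained to x >= 0.
x = solve_qp(G, G @ u, -np.eye(4), np.zeros(4), solver="quadprog")
w_old = x + u

# New formulation: solve for w directly, constrained to w >= u.
w_new = solve_qp(G, np.zeros(4), -np.eye(4), -u, solver="quadprog")

np.testing.assert_allclose(w_old, w_new, atol=1e-6)
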
43 changes: 6 additions & 37 deletions src/torchjd/aggregation/upgrad.py
@@ -1,11 +1,10 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._dual_cone_utils import _project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import (
    _check_pref_vector,
    _pref_vector_to_str_suffix,
@@ -101,37 +100,7 @@ def __init__(
        self.solver = solver

    def forward(self, matrix: Tensor) -> Tensor:
        weights = self.weighting(matrix)
        lagrangian = self._compute_lagrangian(matrix, weights)
        lagrangian_weights = torch.sum(lagrangian, dim=0)
        result_weights = lagrangian_weights + weights
        return result_weights

    def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
        dimension = gramian.shape[0]

        regularization_array = self.reg_eps * np.eye(dimension)
        regularized_gramian_array = gramian_array + regularization_array

        P = regularized_gramian_array
        G = -np.eye(dimension)
        h = np.zeros(dimension)

        lagrangian_rows = []
        for i in range(dimension):
            weight = weights[i].item()
            if weight <= 0.0:
                # In this case, the solution to the quadratic program is always 0,
                # so we don't need to run solve_qp.
                lagrangian_rows.append(np.zeros([dimension]))
            else:
                q = weight * regularized_gramian_array[i, :]
                lagrangian_rows.append(solve_qp(P, q, G, h, solver=self.solver))

        lagrangian_array = np.stack(lagrangian_rows)
        lagrangian = torch.from_numpy(lagrangian_array).to(
            device=gramian.device, dtype=gramian.dtype
        )
        return lagrangian
        U = torch.diag(self.weighting(matrix))
        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
        W = _project_weights(U, G, self.solver)
        return torch.sum(W, dim=0)
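
The new `forward` replaces the explicit per-row loop of `_compute_lagrangian`: projecting all rows of `diag(u)` in a single `_project_weights` call relies on the `np.apply_along_axis` tensorization, and the old shortcut for non-positive weights is no longer needed because the projection of a row with no positive entry is simply zero. A hedged sketch of the equivalence between the batched and row-by-row paths (illustrative data; `_project_weights` is the private helper introduced above):

import torch

from torchjd.aggregation._dual_cone_utils import _project_weights

J = torch.randn(3, 8)
G = J @ J.T + 1e-4 * torch.eye(3)  # regularized Gramian
u = torch.rand(3)
U = torch.diag(u)

# Batched path, as in the new UPGrad.forward.
W = _project_weights(U, G, "quadprog")
w = torch.sum(W, dim=0)  # the aggregation weights returned by forward

# Row-by-row path, mirroring the old loop (which solved for the offsets and
# added the weights back, yielding the same rows).
rows = torch.stack([_project_weights(U[i], G, "quadprog") for i in range(3)])
assert torch.allclose(W, rows, atol=1e-6)
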
69 changes: 69 additions & 0 deletions tests/unit/aggregation/test_dual_cone_utils.py
@@ -0,0 +1,69 @@
import torch
from pytest import mark
from torch.testing import assert_close

from torchjd.aggregation._dual_cone_utils import _project_weights


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_solution_weights(shape: tuple[int, int]):
    r"""
    Tests that `_project_weights` returns valid weights corresponding to the projection onto the
    dual cone of a matrix with the specified shape.

    Validation is performed by verifying that the solution satisfies the `KKT conditions
    <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
    quadratic program that projects vectors onto the dual cone of a matrix. Specifically, the
    solution should satisfy the equivalent set of conditions described in Lemma 4 of [1].

    Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic
    problem of minimizing `v^T G v` subject to `u \preceq v`.

    Then `w` is a solution if and only if it satisfies the following three conditions:
    1. **Dual feasibility:** `u \preceq w`
    2. **Primal feasibility:** `0 \preceq G w`
    3. **Complementary slackness:** `u^T G w = w^T G w`

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
    """

    J = torch.randn(shape)
    G = J @ J.T
    u = torch.rand(shape[0])

    w = _project_weights(u, G, "quadprog")
    dual_gap = w - u

    # Dual feasibility
    dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
    assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)

    primal_gap = G @ w

    # Primal feasibility
    primal_gap_positive_part = primal_gap[primal_gap >= 0]
    assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)

    # Complementary slackness
    slackness = dual_gap @ primal_gap
    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
def test_tensorization_shape(shape: tuple[int, ...]):
    """
    Tests that applying `_project_weights` to a tensor is equivalent to applying it to the tensor
    reshaped as a matrix and then reshaping the result back to the original tensor's shape.
    """

    matrix = torch.randn([shape[-1], shape[-1]])
    U_tensor = torch.randn(shape)
    U_matrix = U_tensor.reshape([-1, shape[-1]])

    G = matrix @ matrix.T

    W_tensor = _project_weights(U_tensor, G, "quadprog")
    W_matrix = _project_weights(U_matrix, G, "quadprog")

    assert_close(W_matrix.reshape(shape), W_tensor)
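
For reference, the three conditions checked in `test_solution_weights` drop out of the standard KKT analysis of the underlying projection problem, with `w - u` playing the role of the (halved) Lagrange multipliers; the feasibility labels refer to that projection problem rather than to the reduced program in `v`. A sketch of the derivation in LaTeX (notation as in the docstring):

\min_x \|x - J^\top u\|^2 \quad \text{s.t.} \quad Jx \succeq 0
\text{Stationarity: } 2(x - J^\top u) = J^\top \lambda \;\Rightarrow\; x = J^\top w, \quad w = u + \lambda/2
\lambda \succeq 0 \;\Rightarrow\; u \preceq w \quad \text{(dual feasibility)}
Jx \succeq 0 \;\Rightarrow\; 0 \preceq Gw \quad \text{(primal feasibility)}
\lambda^\top Jx = 0 \;\Rightarrow\; u^\top Gw = w^\top Gw \quad \text{(complementary slackness)}
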
28 changes: 0 additions & 28 deletions tests/unit/aggregation/test_upgrad.py
@@ -1,10 +1,7 @@
import torch
from pytest import mark
from torch.testing import assert_close

from torchjd.aggregation import UPGrad
from torchjd.aggregation.mean import _MeanWeighting
from torchjd.aggregation.upgrad import _UPGradWrapper

from ._property_testers import (
    ExpectedStructureProperty,
@@ -18,31 +15,6 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI
    pass


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
matrix = torch.randn(shape)
weights = torch.rand(shape[0])

gramian = matrix @ matrix.T

W = _UPGradWrapper(_MeanWeighting(), norm_eps=0.0001, reg_eps=0.0, solver="quadprog")

lagrange_multiplier = W._compute_lagrangian(matrix, weights)

positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0]
assert_close(
positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
)

constraint = gramian @ (torch.diag(weights) + lagrange_multiplier.T)

positive_constraint = constraint[constraint >= 0]
assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)

slackness = torch.trace(lagrange_multiplier @ constraint)
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


def test_representations():
    A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
    assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"