refactor(aggregation): Refactor dual cone projections (#237)
* Simplify the dual cone projection by directly solving the problem of Proposition 1 of the paper
* Add _project_weights, _project_weight_vector and _to_array in _dual_cone_utils.py to factor out dual cone projections
* Rename norm_eps to eps in _compute_normalized_gramian
* Add _compute_regularized_normalized_gramian and _regularize in _gramian_utils.py to factor out regularization of the gramian
* Improve testing of dual cone projections
* Add changelog entry
---------
Co-authored-by: Valérian Rey <[email protected]>
PierreQuinton authored Feb 4, 2025
1 parent 6e6f54c commit 1d87f56
Showing 7 changed files with 164 additions and 94 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ changes that do not affect the user.

## [Unreleased]

### Changed

- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
project onto the dual cone. This may minimally affect the output of these aggregators.

## [0.5.0] - 2025-02-01

### Added
55 changes: 55 additions & 0 deletions src/torchjd/aggregation/_dual_cone_utils.py
@@ -0,0 +1,55 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor


def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
    """
    Computes the tensor of weights corresponding to the projections of the vectors in `U` onto the
    dual cone of the rows of a matrix whose Gramian is provided.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    :return: A tensor of projection weights with the same shape as `U`.
    """

    G_ = _to_array(G)
    U_ = _to_array(U)

    W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_)

    return torch.as_tensor(W, device=G.device, dtype=G.dtype)


def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
    r"""
    Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
    given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
    `\pi_J(J^T u) = J^T w`.

    By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:

        minimize    v^T G v
        subject to  u \preceq v

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

    :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
        project.
    :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    """

    m = G.shape[0]
    w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)
    return w


def _to_array(tensor: Tensor) -> np.ndarray:
    """Transforms a tensor into a numpy array with float64 dtype."""

    return tensor.cpu().detach().numpy().astype(np.float64)
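
A minimal usage sketch of the new helper (an illustration, not part of the commit; it assumes `torch` and `qpsolvers` with its `quadprog` solver are installed, and uses an arbitrary random matrix):

```python
import torch

from torchjd.aggregation._dual_cone_utils import _project_weights

J = torch.randn(5, 7)  # 5 gradients of dimension 7
G = J @ J.T            # Gramian of shape [5, 5]
u = torch.rand(5)      # weights of the vector J^T u to project

w = _project_weights(u, G, "quadprog")

# The quadratic program of Proposition 1 constrains u <= v, so the returned
# weights must dominate u (up to solver tolerance).
assert torch.all(w >= u - 1e-5)
```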
25 changes: 23 additions & 2 deletions src/torchjd/aggregation/_gramian_utils.py
@@ -7,10 +7,16 @@ def _compute_gramian(matrix: Tensor) -> Tensor:
    """
    Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
    """

    return matrix @ matrix.T

def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float) -> Tensor:
    normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
    return _regularize(normalized_gramian, reg_eps)


def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
    r"""
    Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
    :math:`\sigma_\max` is :math:`J`'s largest singular value.
@@ -35,11 +41,26 @@ def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
"issue on https://github.com/TorchJD/torchjd/issues and paste this error message in it."
) from error
max_singular_value = torch.max(singular_values)
if max_singular_value < norm_eps:
if max_singular_value < eps:
scaled_singular_values = torch.zeros_like(singular_values)
else:
scaled_singular_values = singular_values / max_singular_value
normalized_gramian = (
left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
)
return normalized_gramian


def _regularize(gramian: Tensor, eps: float) -> Tensor:
    """
    Adds a regularization term to the gramian to enforce positive definiteness.

    Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a
    regularization term which is a small proportion of the identity matrix ensures that the gramian
    is positive definite.
    """

    regularization_matrix = eps * torch.eye(
        gramian.shape[0], dtype=gramian.dtype, device=gramian.device
    )
    return gramian + regularization_matrix
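
As a sanity check on the factored-out regularization, a small sketch (illustrative values; `eps` plays the role of `reg_eps`) showing that adding `eps * I` shifts every eigenvalue of the gramian up by `eps`, which turns a singular or slightly indefinite matrix into a positive definite one:

```python
import torch

gramian = torch.tensor([[1.0, 1.0], [1.0, 1.0]])  # PSD with eigenvalues {0, 2}
eps = 1e-4

regularized = gramian + eps * torch.eye(2)  # same computation as _regularize

# Every eigenvalue is shifted up by eps, so the result is positive definite.
assert torch.all(torch.linalg.eigvalsh(regularized) > 0)
```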
33 changes: 6 additions & 27 deletions src/torchjd/aggregation/dualproj.py
@@ -1,11 +1,9 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._dual_cone_utils import _project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import (
    _check_pref_vector,
    _pref_vector_to_str_suffix,
@@ -106,26 +104,7 @@ def __init__(
        self.solver = solver

    def forward(self, matrix: Tensor) -> Tensor:
        weights = self.weighting(matrix)
        weights_array = weights.cpu().detach().numpy().astype(np.float64)

        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
        dimension = gramian.shape[0]

        # Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),
        # which makes quadprog misbehave. Adding a regularization term which is a small proportion
        # of the identity matrix ensures that the gramian is positive definite.
        regularization_array = self.reg_eps * np.eye(dimension)
        regularized_gramian_array = gramian_array + regularization_array

        P = regularized_gramian_array
        q = regularized_gramian_array @ weights_array
        G = -np.eye(dimension)
        h = np.zeros(dimension)

        projection_weights_array = solve_qp(P, q, G, h, solver=self.solver)
        projection_weights = torch.from_numpy(projection_weights_array).to(
            device=matrix.device, dtype=matrix.dtype
        )
        return projection_weights + weights
        u = self.weighting(matrix)
        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
        w = _project_weights(u, G, self.solver)
        return w
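
For context, a hedged end-to-end sketch of the refactored aggregator (it assumes the public `torchjd.aggregation.DualProj` API with its default arguments; the input matrix is illustrative). Unlike the previous implementation, which solved for an offset and added it back to the weights, the new quadratic program of Proposition 1 yields the projection weights `w` directly:

```python
import torch

from torchjd.aggregation import DualProj

A = DualProj()              # default arguments, including the quadprog solver
matrix = torch.randn(5, 7)  # 5 gradients of dimension 7

aggregated = A(matrix)      # projection of the weighted gradient onto the dual cone
print(aggregated.shape)     # torch.Size([7])
```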
43 changes: 6 additions & 37 deletions src/torchjd/aggregation/upgrad.py
@@ -1,11 +1,10 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor

from ._gramian_utils import _compute_normalized_gramian
from ._dual_cone_utils import _project_weights
from ._gramian_utils import _compute_regularized_normalized_gramian
from ._pref_vector_utils import (
    _check_pref_vector,
    _pref_vector_to_str_suffix,
@@ -101,37 +100,7 @@ def __init__(
        self.solver = solver

    def forward(self, matrix: Tensor) -> Tensor:
        weights = self.weighting(matrix)
        lagrangian = self._compute_lagrangian(matrix, weights)
        lagrangian_weights = torch.sum(lagrangian, dim=0)
        result_weights = lagrangian_weights + weights
        return result_weights

    def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
        dimension = gramian.shape[0]

        regularization_array = self.reg_eps * np.eye(dimension)
        regularized_gramian_array = gramian_array + regularization_array

        P = regularized_gramian_array
        G = -np.eye(dimension)
        h = np.zeros(dimension)

        lagrangian_rows = []
        for i in range(dimension):
            weight = weights[i].item()
            if weight <= 0.0:
                # In this case, the solution to the quadratic program is always 0,
                # so we don't need to run solve_qp.
                lagrangian_rows.append(np.zeros([dimension]))
            else:
                q = weight * regularized_gramian_array[i, :]
                lagrangian_rows.append(solve_qp(P, q, G, h, solver=self.solver))

        lagrangian_array = np.stack(lagrangian_rows)
        lagrangian = torch.from_numpy(lagrangian_array).to(
            device=gramian.device, dtype=gramian.dtype
        )
        return lagrangian
        U = torch.diag(self.weighting(matrix))
        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
        W = _project_weights(U, G, self.solver)
        return torch.sum(W, dim=0)
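
The new `forward` stacks the weight vectors `u_i * e_i` as the rows of `U = torch.diag(u)`, projects them all in a single `_project_weights` call, and sums the resulting rows. A sketch of the equivalence with a row-by-row loop (illustrative shapes; the helper names are the ones added by this commit):

```python
import torch

from torchjd.aggregation._dual_cone_utils import _project_weights

m = 4
J = torch.randn(m, 6)
G = J @ J.T        # almost surely positive definite for this shape
u = torch.rand(m)  # e.g. the output of a mean weighting

# Vectorized form used by the refactored UPGrad.forward:
W = _project_weights(torch.diag(u), G, "quadprog")  # shape [m, m]
w_vectorized = torch.sum(W, dim=0)

# Equivalent row-by-row form: project each u_i * e_i separately.
rows = [_project_weights(u[i] * torch.eye(m)[i], G, "quadprog") for i in range(m)]
w_loop = torch.stack(rows).sum(dim=0)

torch.testing.assert_close(w_vectorized, w_loop)
```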
69 changes: 69 additions & 0 deletions tests/unit/aggregation/test_dual_cone_utils.py
@@ -0,0 +1,69 @@
import torch
from pytest import mark
from torch.testing import assert_close

from torchjd.aggregation._dual_cone_utils import _project_weights


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_solution_weights(shape: tuple[int, int]):
    r"""
    Tests that `_project_weights` returns valid weights corresponding to the projection onto the
    dual cone of a matrix with the specified shape.

    Validation is performed by verifying that the solution satisfies the `KKT conditions
    <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
    quadratic program that projects vectors onto the dual cone of a matrix. Specifically, the
    solution should satisfy the equivalent set of conditions described in Lemma 4 of [1].

    Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic
    problem of minimizing `v^T G v` subject to `u \preceq v`.

    Then `w` is a solution if and only if it satisfies the following three conditions:
    1. **Dual feasibility:** `u \preceq w`
    2. **Primal feasibility:** `0 \preceq G w`
    3. **Complementary slackness:** `u^T G w = w^T G w`

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
    """

    J = torch.randn(shape)
    G = J @ J.T
    u = torch.rand(shape[0])

    w = _project_weights(u, G, "quadprog")
    dual_gap = w - u

    # Dual feasibility
    dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
    assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)

    primal_gap = G @ w

    # Primal feasibility
    primal_gap_positive_part = primal_gap[primal_gap >= 0]
    assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)

    # Complementary slackness
    slackness = dual_gap @ primal_gap
    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
def test_tensorization_shape(shape: tuple[int, ...]):
"""
Tests that applying `_project_weights` on a tensor is equivalent to applying it on the tensor
reshaped as matrix and to reshape the result back to the original tensor's shape.
"""

matrix = torch.randn([shape[-1], shape[-1]])
U_tensor = torch.randn(shape)
U_matrix = U_tensor.reshape([-1, shape[-1]])

G = matrix @ matrix.T

W_tensor = _project_weights(U_tensor, G, "quadprog")
W_matrix = _project_weights(U_matrix, G, "quadprog")

assert_close(W_matrix.reshape(shape), W_tensor)
28 changes: 0 additions & 28 deletions tests/unit/aggregation/test_upgrad.py
@@ -1,10 +1,7 @@
import torch
from pytest import mark
from torch.testing import assert_close

from torchjd.aggregation import UPGrad
from torchjd.aggregation.mean import _MeanWeighting
from torchjd.aggregation.upgrad import _UPGradWrapper

from ._property_testers import (
    ExpectedStructureProperty,
@@ -18,31 +15,6 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI
    pass


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
    matrix = torch.randn(shape)
    weights = torch.rand(shape[0])

    gramian = matrix @ matrix.T

    W = _UPGradWrapper(_MeanWeighting(), norm_eps=0.0001, reg_eps=0.0, solver="quadprog")

    lagrange_multiplier = W._compute_lagrangian(matrix, weights)

    positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0]
    assert_close(
        positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
    )

    constraint = gramian @ (torch.diag(weights) + lagrange_multiplier.T)

    positive_constraint = constraint[constraint >= 0]
    assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)

    slackness = torch.trace(lagrange_multiplier @ constraint)
    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


def test_representations():
    A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
    assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"
