From 1d87f56a589bbc971794558fec91a619f387b8f8 Mon Sep 17 00:00:00 2001
From: Pierre Quinton
Date: Wed, 5 Feb 2025 00:17:32 +0100
Subject: [PATCH] refactor(aggregation): Refactor dual cone projections (#237)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Simplify the dual cone projection by directly solving the problem of Proposition 1 of the paper
* Add _project_weights, _project_weight_vector and _to_array in _dual_cone_utils.py to factor out dual cone projections
* Rename norm_eps to eps in _compute_normalized_gramian
* Add _compute_regularized_normalized_gramian and _regularize in _gramian_utils.py to factor out the regularization of the Gramian
* Improve testing of dual cone projections
* Add changelog entry

---------

Co-authored-by: Valérian Rey
---
 CHANGELOG.md                                  |  5 ++
 src/torchjd/aggregation/_dual_cone_utils.py   | 55 +++++++++++++++
 src/torchjd/aggregation/_gramian_utils.py     | 25 ++++++-
 src/torchjd/aggregation/dualproj.py           | 33 ++-------
 src/torchjd/aggregation/upgrad.py             | 43 ++----------
 .../unit/aggregation/test_dual_cone_utils.py  | 69 +++++++++++++++++++
 tests/unit/aggregation/test_upgrad.py         | 28 --------
 7 files changed, 164 insertions(+), 94 deletions(-)
 create mode 100644 src/torchjd/aggregation/_dual_cone_utils.py
 create mode 100644 tests/unit/aggregation/test_dual_cone_utils.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e54a47a..7a917a0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,11 @@ changes that do not affect the user.
 
 ## [Unreleased]
 
+### Changed
+
+- Refactored the underlying optimization problem that `UPGrad` and `DualProj` have to solve to
+  project onto the dual cone. This may minimally affect the output of these aggregators.
+
 ## [0.5.0] - 2025-02-01
 
 ### Added
diff --git a/src/torchjd/aggregation/_dual_cone_utils.py b/src/torchjd/aggregation/_dual_cone_utils.py
new file mode 100644
index 0000000..c844ade
--- /dev/null
+++ b/src/torchjd/aggregation/_dual_cone_utils.py
@@ -0,0 +1,55 @@
+from typing import Literal
+
+import numpy as np
+import torch
+from qpsolvers import solve_qp
+from torch import Tensor
+
+
+def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
+    """
+    Computes the tensor of weights corresponding to the projections of the vectors in `U` onto
+    the dual cone of the rows of a matrix whose Gramian is provided.
+
+    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
+    :param G: The Gramian matrix of shape `[m, m]`.
+    :param solver: The quadratic programming solver to use.
+    :return: A tensor of projection weights with the same shape as `U`.
+    """
+
+    G_ = _to_array(G)
+    U_ = _to_array(U)
+
+    W = np.apply_along_axis(lambda u: _project_weight_vector(u, G_, solver), axis=-1, arr=U_)
+
+    return torch.as_tensor(W, device=G.device, dtype=G.dtype)
+
+
+def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
+    r"""
+    Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
+    given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
+    `\pi_J(J^T u) = J^T w`.
+
+    By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:
+        minimize v^T G v
+        subject to u \preceq v
+
+    Reference:
+    [1] `Jacobian Descent For Multi-Objective Optimization `_.
+
+    :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
+        project.
+    :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`.
+    :param solver: The quadratic programming solver to use.
+    """
+
+    m = G.shape[0]
+    w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)
+    return w
+
+
+def _to_array(tensor: Tensor) -> np.ndarray:
+    """Transforms a tensor into a numpy array with float64 dtype."""
+
+    return tensor.cpu().detach().numpy().astype(np.float64)
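A note on the equivalence this helper relies on: the quadratic program above works entirely in weight space, so the QP has size `m` (the number of rows of `J`) rather than the number of parameters. The following sketch (illustrative, not part of the diff; it assumes `numpy`, `torch`, and `qpsolvers` with the `quadprog` backend are available) cross-checks Proposition 1 by comparing `J^T w` against a direct Euclidean projection of `J^T u` onto the dual cone of the rows of `J`:

```python
import numpy as np
import torch
from qpsolvers import solve_qp

from torchjd.aggregation._dual_cone_utils import _project_weights

rng = np.random.default_rng(0)
J = rng.standard_normal((4, 10))  # 4 rows; their dual cone is {y : J y >= 0}
u = rng.random(4)
x = J.T @ u  # the vector to project

# Direct projection: minimize 0.5 * ||y - x||^2 subject to J y >= 0.
y = solve_qp(np.eye(10), -x, -J, np.zeros(4), solver="quadprog")

# Weight-space projection from this patch: J^T w should coincide with y.
w = _project_weights(torch.as_tensor(u), torch.as_tensor(J @ J.T), "quadprog")
assert np.allclose(J.T @ w.numpy(), y, atol=1e-6)  # tolerance is illustrative
```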
diff --git a/src/torchjd/aggregation/_gramian_utils.py b/src/torchjd/aggregation/_gramian_utils.py
index 34efefe..bb12fb4 100644
--- a/src/torchjd/aggregation/_gramian_utils.py
+++ b/src/torchjd/aggregation/_gramian_utils.py
@@ -7,10 +7,16 @@ def _compute_gramian(matrix: Tensor) -> Tensor:
     """
     Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of a given matrix.
     """
+
     return matrix @ matrix.T
 
 
-def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
+def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
+    normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
+    return _regularize(normalized_gramian, reg_eps)
+
+
+def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
     r"""
     Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
     :math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
@@ -35,7 +41,7 @@ def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
             "issue on https://github.com/TorchJD/torchjd/issues and paste this error message in it."
         ) from error
     max_singular_value = torch.max(singular_values)
-    if max_singular_value < norm_eps:
+    if max_singular_value < eps:
         scaled_singular_values = torch.zeros_like(singular_values)
     else:
         scaled_singular_values = singular_values / max_singular_value
@@ -43,3 +49,18 @@
         left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
     )
     return normalized_gramian
+
+
+def _regularize(gramian: Tensor, eps: float) -> Tensor:
+    """
+    Adds a regularization term to the gramian to enforce positive definiteness.
+
+    Because of numerical errors, `gramian` might have slightly negative eigenvalue(s). Adding a
+    regularization term which is a small proportion of the identity matrix ensures that the
+    gramian is positive definite.
+    """
+
+    regularization_matrix = eps * torch.eye(
+        gramian.shape[0], dtype=gramian.dtype, device=gramian.device
+    )
+    return gramian + regularization_matrix
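To illustrate what `_regularize` guards against (a minimal sketch, not part of the diff): a rank-deficient Gramian has zero eigenvalues that floating-point arithmetic can render slightly negative, which makes quadprog misbehave; adding `eps * I` shifts the whole spectrum up by `eps`:

```python
import torch

from torchjd.aggregation._gramian_utils import _regularize

J = torch.ones(3, 5, dtype=torch.float64)  # rank-1 matrix
G = J @ J.T  # eigenvalues are {15, 0, 0}

G_reg = _regularize(G, eps=1e-4)

print(torch.linalg.eigvalsh(G).min().item())      # ~0, possibly slightly negative
print(torch.linalg.eigvalsh(G_reg).min().item())  # ~1e-4, hence positive definite
```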
diff --git a/src/torchjd/aggregation/dualproj.py b/src/torchjd/aggregation/dualproj.py
index 7d7fbe0..eb55cef 100644
--- a/src/torchjd/aggregation/dualproj.py
+++ b/src/torchjd/aggregation/dualproj.py
@@ -1,11 +1,9 @@
 from typing import Literal
 
-import numpy as np
-import torch
-from qpsolvers import solve_qp
 from torch import Tensor
 
-from ._gramian_utils import _compute_normalized_gramian
+from ._dual_cone_utils import _project_weights
+from ._gramian_utils import _compute_regularized_normalized_gramian
 from ._pref_vector_utils import (
     _check_pref_vector,
     _pref_vector_to_str_suffix,
@@ -106,26 +104,7 @@ def __init__(
         self.solver = solver
 
     def forward(self, matrix: Tensor) -> Tensor:
-        weights = self.weighting(matrix)
-        weights_array = weights.cpu().detach().numpy().astype(np.float64)
-
-        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
-        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
-        dimension = gramian.shape[0]
-
-        # Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),
-        # which makes quadprog misbehave. Adding a regularization term which is a small proportion
-        # of the identity matrix ensures that the gramian is positive definite.
-        regularization_array = self.reg_eps * np.eye(dimension)
-        regularized_gramian_array = gramian_array + regularization_array
-
-        P = regularized_gramian_array
-        q = regularized_gramian_array @ weights_array
-        G = -np.eye(dimension)
-        h = np.zeros(dimension)
-
-        projection_weights_array = solve_qp(P, q, G, h, solver=self.solver)
-        projection_weights = torch.from_numpy(projection_weights_array).to(
-            device=matrix.device, dtype=matrix.dtype
-        )
-        return projection_weights + weights
+        u = self.weighting(matrix)
+        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
+        w = _project_weights(u, G, self.solver)
+        return w
diff --git a/src/torchjd/aggregation/upgrad.py b/src/torchjd/aggregation/upgrad.py
index 2050fd4..fdc8394 100644
--- a/src/torchjd/aggregation/upgrad.py
+++ b/src/torchjd/aggregation/upgrad.py
@@ -1,11 +1,10 @@
 from typing import Literal
 
-import numpy as np
 import torch
-from qpsolvers import solve_qp
 from torch import Tensor
 
-from ._gramian_utils import _compute_normalized_gramian
+from ._dual_cone_utils import _project_weights
+from ._gramian_utils import _compute_regularized_normalized_gramian
 from ._pref_vector_utils import (
     _check_pref_vector,
     _pref_vector_to_str_suffix,
@@ -101,37 +100,7 @@ def __init__(
         self.solver = solver
 
     def forward(self, matrix: Tensor) -> Tensor:
-        weights = self.weighting(matrix)
-        lagrangian = self._compute_lagrangian(matrix, weights)
-        lagrangian_weights = torch.sum(lagrangian, dim=0)
-        result_weights = lagrangian_weights + weights
-        return result_weights
-
-    def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
-        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
-        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
-        dimension = gramian.shape[0]
-
-        regularization_array = self.reg_eps * np.eye(dimension)
-        regularized_gramian_array = gramian_array + regularization_array
-
-        P = regularized_gramian_array
-        G = -np.eye(dimension)
-        h = np.zeros(dimension)
-
-        lagrangian_rows = []
-        for i in range(dimension):
-            weight = weights[i].item()
-            if weight <= 0.0:
-                # In this case, the solution to the quadratic program is always 0,
-                # so we don't need to run solve_qp.
-                lagrangian_rows.append(np.zeros([dimension]))
-            else:
-                q = weight * regularized_gramian_array[i, :]
-                lagrangian_rows.append(solve_qp(P, q, G, h, solver=self.solver))
-
-        lagrangian_array = np.stack(lagrangian_rows)
-        lagrangian = torch.from_numpy(lagrangian_array).to(
-            device=gramian.device, dtype=gramian.dtype
-        )
-        return lagrangian
+        U = torch.diag(self.weighting(matrix))
+        G = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
+        W = _project_weights(U, G, self.solver)
+        return torch.sum(W, dim=0)
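End to end, the refactor is meant to leave the aggregators' behavior unchanged up to numerical precision. A minimal usage sketch (values are illustrative; it assumes the public `DualProj` and `UPGrad` aggregators are importable from `torchjd.aggregation` and callable on a gradient matrix, as in the test suite below):

```python
import torch

from torchjd.aggregation import DualProj, UPGrad

J = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])  # two conflicting gradients

for aggregator in (DualProj(), UPGrad()):
    g = aggregator(J)
    # The result lies in the dual cone of the rows of J, so it conflicts with
    # no row: every entry of J @ g is >= 0 (up to the solver's tolerance).
    print(aggregator, g, J @ g)
```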
diff --git a/tests/unit/aggregation/test_dual_cone_utils.py b/tests/unit/aggregation/test_dual_cone_utils.py
new file mode 100644
index 0000000..3952840
--- /dev/null
+++ b/tests/unit/aggregation/test_dual_cone_utils.py
@@ -0,0 +1,69 @@
+import torch
+from pytest import mark
+from torch.testing import assert_close
+
+from torchjd.aggregation._dual_cone_utils import _project_weights
+
+
+@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
+def test_solution_weights(shape: tuple[int, int]):
+    r"""
+    Tests that `_project_weights` returns valid weights corresponding to the projection onto the
+    dual cone of a matrix with the specified shape.
+
+    Validation is performed by verifying that the solution satisfies the `KKT conditions
+    <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
+    quadratic program that projects vectors onto the dual cone of a matrix. Specifically, the
+    solution should satisfy the equivalent set of conditions described in Lemma 4 of [1].
+
+    Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic
+    problem of minimizing `v^T G v` subject to `u \preceq v`.
+
+    Then `w` is a solution if and only if it satisfies the following three conditions:
+    1. **Dual feasibility:** `u \preceq w`
+    2. **Primal feasibility:** `0 \preceq G w`
+    3. **Complementary slackness:** `u^T G w = w^T G w`
+
+    Reference:
+    [1] `Jacobian Descent For Multi-Objective Optimization `_.
+    """
+
+    J = torch.randn(shape)
+    G = J @ J.T
+    u = torch.rand(shape[0])
+
+    w = _project_weights(u, G, "quadprog")
+    dual_gap = w - u
+
+    # Dual feasibility
+    dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
+    assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)
+
+    primal_gap = G @ w
+
+    # Primal feasibility
+    primal_gap_positive_part = primal_gap[primal_gap >= 0]
+    assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)
+
+    # Complementary slackness
+    slackness = dual_gap @ primal_gap
+    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
+
+
+@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
+def test_tensorization_shape(shape: tuple[int, ...]):
+    """
+    Tests that applying `_project_weights` to a tensor is equivalent to applying it to the tensor
+    reshaped as a matrix and then reshaping the result back to the original tensor's shape.
+    """
+
+    matrix = torch.randn([shape[-1], shape[-1]])
+    U_tensor = torch.randn(shape)
+    U_matrix = U_tensor.reshape([-1, shape[-1]])
+
+    G = matrix @ matrix.T
+
+    W_tensor = _project_weights(U_tensor, G, "quadprog")
+    W_matrix = _project_weights(U_matrix, G, "quadprog")
+
+    assert_close(W_matrix.reshape(shape), W_tensor)
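For reference, here is a short sketch of the standard Lagrangian argument behind those three conditions (not part of the diff; `x = J^T u` is projected onto `K = {y : J y >= 0}` and `G = J J^T`, matching the notation of the docstring above):

```latex
\begin{align*}
  &\text{Problem:}      && \min_y \tfrac{1}{2}\lVert y - J^\top u\rVert^2
                           \quad \text{s.t.} \quad J y \succeq 0\\
  &\text{Lagrangian:}   && \mathcal{L}(y, \mu) = \tfrac{1}{2}\lVert y - J^\top u\rVert^2
                           - \mu^\top J y, \qquad \mu \succeq 0\\
  &\text{Stationarity:} && y = J^\top (u + \mu),
                           \text{ so the solution has weights } w = u + \mu\\
  &\text{Dual feas.:}   && \mu \succeq 0 \iff u \preceq w\\
  &\text{Primal feas.:} && J y \succeq 0 \iff 0 \preceq G w\\
  &\text{Slackness:}    && \mu^\top J y = 0 \iff (w - u)^\top G w = 0
                           \iff u^\top G w = w^\top G w
\end{align*}
```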
diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py
index ba15c32..4f2abb2 100644
--- a/tests/unit/aggregation/test_upgrad.py
+++ b/tests/unit/aggregation/test_upgrad.py
@@ -1,10 +1,7 @@
 import torch
 from pytest import mark
-from torch.testing import assert_close
 
 from torchjd.aggregation import UPGrad
-from torchjd.aggregation.mean import _MeanWeighting
-from torchjd.aggregation.upgrad import _UPGradWrapper
 
 from ._property_testers import (
     ExpectedStructureProperty,
@@ -18,31 +15,6 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI
     pass
 
 
-@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
-def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
-    matrix = torch.randn(shape)
-    weights = torch.rand(shape[0])
-
-    gramian = matrix @ matrix.T
-
-    W = _UPGradWrapper(_MeanWeighting(), norm_eps=0.0001, reg_eps=0.0, solver="quadprog")
-
-    lagrange_multiplier = W._compute_lagrangian(matrix, weights)
-
-    positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0]
-    assert_close(
-        positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
-    )
-
-    constraint = gramian @ (torch.diag(weights) + lagrange_multiplier.T)
-
-    positive_constraint = constraint[constraint >= 0]
-    assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)
-
-    slackness = torch.trace(lagrange_multiplier @ constraint)
-    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
-
-
 def test_representations():
     A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
     assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"