Refactor dual cone projections #237

Merged 36 commits on Feb 4, 2025

Commits
aa9897e
Factorize projection onto the dual cone.
PierreQuinton Jan 21, 2025
f534006
Move regularization out of projection onto the dual cone.
PierreQuinton Jan 21, 2025
864d7a4
Improve code coverage
PierreQuinton Jan 21, 2025
b0739fe
Add link to paper in weights_of_projection
PierreQuinton Feb 1, 2025
cd43f65
Rename functions in _dual_cone_utils
PierreQuinton Feb 1, 2025
402024c
Rename lagrangian to lagrangian_multipliers to match the maths
PierreQuinton Feb 1, 2025
a7ae47f
Separate _get_projection_weights into a tensorization of _get_lagrang…
PierreQuinton Feb 1, 2025
895b123
call the cast to numpy (and thus CPU) only once on the weights.
PierreQuinton Feb 1, 2025
25c5666
Improve the tensorization of _get_lagrange_multipliers_array
PierreQuinton Feb 1, 2025
8c52aa5
Improve readability of tensorization by using a for comprehension.
PierreQuinton Feb 1, 2025
3591577
Exchange _compute_normalized_regularized_gramian and _regularize_gramian
PierreQuinton Feb 1, 2025
c631f30
Add comment explaining the goal of regularization
PierreQuinton Feb 1, 2025
4aa0cbd
Rename:
PierreQuinton Feb 1, 2025
193ff9c
Exchange _normalize and _normalize_and_regularize
PierreQuinton Feb 1, 2025
a62c708
rename epsilons in _normalize and _regularize
PierreQuinton Feb 1, 2025
fb07035
Set explicitly the dtype and device of regularization matrix
PierreQuinton Feb 2, 2025
756e8c5
Rename functions & variables, add _to_array
ValerianRey Feb 2, 2025
601bc6e
Fix bug of transposed output
ValerianRey Feb 2, 2025
b3615cd
Delete outdated tests
ValerianRey Feb 2, 2025
c9afb41
Rename functions _normalize etc
ValerianRey Feb 2, 2025
e141600
Improve docstring and naming of test of kkt conditions.
PierreQuinton Feb 2, 2025
9c7eaac
Improve docstring of _get_projection_weights
PierreQuinton Feb 2, 2025
f04581e
Add test of tensorization of _get_projection_weights
PierreQuinton Feb 2, 2025
a5f562a
Clarify the QP problem that is solved, it now matches exactly that of…
PierreQuinton Feb 3, 2025
627adf0
Improve docstring of _get_projection_weight_vector
PierreQuinton Feb 3, 2025
bb411c9
Add changelog entry
PierreQuinton Feb 4, 2025
782aad6
Add no_grad context to the computation of UPGrad and DualProj as we c…
PierreQuinton Feb 4, 2025
ffca7d0
Improve dual cone utility functions
ValerianRey Feb 4, 2025
60a3851
Improve variable names:
ValerianRey Feb 4, 2025
c7a2fe8
Improve docstring of test_solution_weights
ValerianRey Feb 4, 2025
633b286
Remove torch.no_grad in DualProj and UPGrad
ValerianRey Feb 4, 2025
4f3e48a
Add docstring to test_tensorization_shape
ValerianRey Feb 4, 2025
0c224fc
Improve docstring formatting
ValerianRey Feb 4, 2025
c759201
Simplify UPGrad
ValerianRey Feb 4, 2025
c604c80
Add docstring to _regularize
ValerianRey Feb 4, 2025
20aa85c
Improve formatting
ValerianRey Feb 4, 2025
60 changes: 60 additions & 0 deletions src/torchjd/aggregation/_dual_cone_utils.py
@@ -0,0 +1,60 @@
from typing import Literal

import numpy as np
import torch
from qpsolvers import solve_qp
from torch import Tensor


def _get_projection_weights(
    gramian: Tensor, weights: Tensor, solver: Literal["quadprog"]
) -> Tensor:
    r"""
    Computes the weights of the projection of a vector of weights onto the dual cone of a matrix
    whose gramian is provided. Specifically, this solves for $w$ in the problem defined by (5) in
    Proposition 1 of [1] when the gramian is $JJ^\top$ and $v$ is given by ``weights``.
    This is a vectorized version: ``weights`` can also be a tensor whose rows (along the last
    dimension) are weight vectors, which are then projected independently.

    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
    """
    lagrange_multipliers = _get_lagrange_multipliers(gramian, weights, solver)
    return lagrange_multipliers + weights


def _get_lagrange_multipliers(
    gramian: Tensor, weights: Tensor, solver: Literal["quadprog"]
) -> Tensor:
    weight_matrix = _to_array(weights.reshape([-1, weights.shape[-1]]))
    gramian_array = _to_array(gramian)

    lagrange_multiplier_vectors = [
        _get_lagrange_multiplier_vector(gramian_array, weight_vector, solver)
        for weight_vector in weight_matrix
    ]

    lagrange_multiplier_matrix = np.stack(lagrange_multiplier_vectors)
    lagrange_multipliers = (
        torch.from_numpy(lagrange_multiplier_matrix)
        .to(device=gramian.device, dtype=gramian.dtype)
        .reshape(weights.shape)
    )
    return lagrange_multipliers


def _get_lagrange_multiplier_vector(
    gramian: np.ndarray, weight_vector: np.ndarray, solver: Literal["quadprog"]
) -> np.ndarray:
    """
    Solves the dual of the projection of a vector of weights onto the dual cone of the matrix J
    whose gramian is given.
    """
    dimension = gramian.shape[0]
    P = gramian
    q = gramian @ weight_vector
    G = -np.eye(dimension)
    h = np.zeros(dimension)
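    # In qpsolvers' standard form, `solve_qp` minimizes 0.5 * x^T P x + q^T x subject to
    # G x <= h. With P set to the gramian, q = P @ weight_vector, G = -I and h = 0, the
    # solution is the vector of Lagrange multipliers of the projection problem;
    # `_get_projection_weights` then recovers the projection weights as weights + multipliers.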
    return solve_qp(P, q, G, h, solver=solver)


def _to_array(tensor: Tensor) -> np.ndarray:
    return tensor.cpu().detach().numpy().astype(np.float64)
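Note for reviewers (not part of the diff): a minimal sketch of how the vectorized helper behaves, assuming this branch is installed along with qpsolvers and quadprog:

import torch

from torchjd.aggregation._dual_cone_utils import _get_projection_weights

matrix = torch.randn(5, 7)  # J: 5 objectives, 7 parameters
gramian = matrix @ matrix.T  # G = J J^T, positive semi-definite
weights = torch.rand(3, 5)  # three weight vectors, batched along the first dimension

# Each of the 3 rows is projected independently; the output has the same shape as `weights`.
projection_weights = _get_projection_weights(gramian, weights, "quadprog")
assert projection_weights.shape == (3, 5)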
18 changes: 16 additions & 2 deletions src/torchjd/aggregation/_gramian_utils.py
@@ -10,7 +10,12 @@ def _compute_gramian(matrix: Tensor) -> Tensor:
     return matrix @ matrix.T


-def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
+def _compute_regularized_normalized_gramian(matrix: Tensor, norm_eps: float, reg_eps: float):
+    normalized_gramian = _compute_normalized_gramian(matrix, norm_eps)
+    return _regularize(normalized_gramian, reg_eps)
+
+
+def _compute_normalized_gramian(matrix: Tensor, eps: float) -> Tensor:
r"""
Computes :math:`\frac{1}{\sigma_\max^2} J J^T` for an input matrix :math:`J`, where
:math:`{\sigma_\max^2}` is :math:`J`'s largest singular value.
@@ -35,11 +40,20 @@ def _compute_normalized_gramian(matrix: Tensor, norm_eps: float) -> Tensor:
             "issue on https://github.com/TorchJD/torchjd/issues and paste this error message in it."
         ) from error
     max_singular_value = torch.max(singular_values)
-    if max_singular_value < norm_eps:
+    if max_singular_value < eps:
         scaled_singular_values = torch.zeros_like(singular_values)
     else:
         scaled_singular_values = singular_values / max_singular_value
     normalized_gramian = (
         left_unitary_matrix @ torch.diag(scaled_singular_values**2) @ left_unitary_matrix.T
     )
     return normalized_gramian


+def _regularize(gramian: Tensor, eps: float) -> Tensor:
+    # Because of numerical errors, `gramian` might have slightly negative eigenvalue(s).
+    # Adding a regularization term which is a small proportion of the identity matrix
+    # ensures that the gramian is positive definite.
+    regularization_matrix = eps * torch.eye(
+        gramian.shape[0], dtype=gramian.dtype, device=gramian.device
+    )
+    return gramian + regularization_matrix
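Note for reviewers (illustrative sketch, not part of the diff): a gramian whose smallest eigenvalue has drifted slightly below zero becomes positive definite again once `eps * I` is added, since every eigenvalue is shifted up by exactly `eps`:

import torch

# Hypothetical gramian with a slightly negative eigenvalue caused by numerical error.
gramian = torch.tensor([[1.0, 0.0], [0.0, -1e-8]])
eps = 1e-4

# What _regularize does: add a small multiple of the identity.
regularized = gramian + eps * torch.eye(2, dtype=gramian.dtype, device=gramian.device)
assert torch.linalg.eigvalsh(regularized).min() > 0  # now positive definite, safe for quadprog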
30 changes: 4 additions & 26 deletions src/torchjd/aggregation/dualproj.py
@@ -1,11 +1,9 @@
 from typing import Literal
 
-import numpy as np
-import torch
-from qpsolvers import solve_qp
 from torch import Tensor
 
-from ._gramian_utils import _compute_normalized_gramian
+from ._dual_cone_utils import _get_projection_weights
+from ._gramian_utils import _compute_regularized_normalized_gramian
 from ._pref_vector_utils import (
     _check_pref_vector,
     _pref_vector_to_str_suffix,
@@ -107,25 +105,5 @@

     def forward(self, matrix: Tensor) -> Tensor:
         weights = self.weighting(matrix)
-        weights_array = weights.cpu().detach().numpy().astype(np.float64)
-
-        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
-        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
-        dimension = gramian.shape[0]
-
-        # Because of numerical errors, `gramian_array` might have slightly negative eigenvalue(s),
-        # which makes quadprog misbehave. Adding a regularization term which is a small proportion
-        # of the identity matrix ensures that the gramian is positive definite.
-        regularization_array = self.reg_eps * np.eye(dimension)
-        regularized_gramian_array = gramian_array + regularization_array
-
-        P = regularized_gramian_array
-        q = regularized_gramian_array @ weights_array
-        G = -np.eye(dimension)
-        h = np.zeros(dimension)
-
-        projection_weights_array = solve_qp(P, q, G, h, solver=self.solver)
-        projection_weights = torch.from_numpy(projection_weights_array).to(
-            device=matrix.device, dtype=matrix.dtype
-        )
-        return projection_weights + weights
+        gramian = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
+        return _get_projection_weights(gramian, weights, self.solver)
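Note for reviewers (not part of the diff): a short usage sketch of the refactored aggregator, assuming the public DualProj export and the example style used in torchjd's docs:

import torch

from torchjd.aggregation import DualProj

aggregator = DualProj()
jacobian = torch.tensor([[-4.0, 1.0, 1.0], [6.0, 1.0, 1.0]])  # two conflicting gradient rows

# Aggregate by projecting the weighted combination of the rows onto the dual cone of the
# rows of the jacobian: the result should conflict with none of the original gradients.
aggregated = aggregator(jacobian)
assert bool((jacobian @ aggregated >= -1e-6).all())  # non-conflicting, up to tolerance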
41 changes: 6 additions & 35 deletions src/torchjd/aggregation/upgrad.py
@@ -1,11 +1,10 @@
 from typing import Literal
 
-import numpy as np
 import torch
-from qpsolvers import solve_qp
 from torch import Tensor
 
-from ._gramian_utils import _compute_normalized_gramian
+from ._dual_cone_utils import _get_projection_weights
+from ._gramian_utils import _compute_regularized_normalized_gramian
 from ._pref_vector_utils import (
     _check_pref_vector,
     _pref_vector_to_str_suffix,
@@ -102,36 +101,8 @@

     def forward(self, matrix: Tensor) -> Tensor:
         weights = self.weighting(matrix)
-        lagrangian = self._compute_lagrangian(matrix, weights)
-        lagrangian_weights = torch.sum(lagrangian, dim=0)
-        result_weights = lagrangian_weights + weights
-        return result_weights
-
-    def _compute_lagrangian(self, matrix: Tensor, weights: Tensor) -> Tensor:
-        gramian = _compute_normalized_gramian(matrix, self.norm_eps)
-        gramian_array = gramian.cpu().detach().numpy().astype(np.float64)
-        dimension = gramian.shape[0]
-
-        regularization_array = self.reg_eps * np.eye(dimension)
-        regularized_gramian_array = gramian_array + regularization_array
-
-        P = regularized_gramian_array
-        G = -np.eye(dimension)
-        h = np.zeros(dimension)
-
-        lagrangian_rows = []
-        for i in range(dimension):
-            weight = weights[i].item()
-            if weight <= 0.0:
-                # In this case, the solution to the quadratic program is always 0,
-                # so we don't need to run solve_qp.
-                lagrangian_rows.append(np.zeros([dimension]))
-            else:
-                q = weight * regularized_gramian_array[i, :]
-                lagrangian_rows.append(solve_qp(P, q, G, h, solver=self.solver))
-
-        lagrangian_array = np.stack(lagrangian_rows)
-        lagrangian = torch.from_numpy(lagrangian_array).to(
-            device=gramian.device, dtype=gramian.dtype
+        gramian = _compute_regularized_normalized_gramian(matrix, self.norm_eps, self.reg_eps)
+        projection_weights_matrix = _get_projection_weights(
+            gramian, torch.diag(weights), self.solver
         )
-        return lagrangian
+        return torch.sum(projection_weights_matrix, dim=0)
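Note for reviewers (not part of the diff): row i of torch.diag(weights) is w_i * e_i, so the batched call projects each weighted basis vector onto the dual cone separately, and summing the resulting rows yields the same aggregation weights as the removed per-index loop. A sketch of the equivalence, assuming this branch is installed:

import torch

from torchjd.aggregation._dual_cone_utils import _get_projection_weights

matrix = torch.randn(4, 10)
gramian = matrix @ matrix.T
weights = torch.rand(4)

# Batched: all rows of diag(weights) are projected at once.
batched = _get_projection_weights(gramian, torch.diag(weights), "quadprog")

# Row by row: project each w_i * e_i separately, as the old per-index loop did.
rows = [
    _get_projection_weights(gramian, weights[i] * torch.eye(4)[i], "quadprog")
    for i in range(4)
]
assert torch.allclose(batched, torch.stack(rows), atol=1e-6)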
53 changes: 53 additions & 0 deletions tests/unit/aggregation/test_dual_cone_utils.py
@@ -0,0 +1,53 @@
import torch
from pytest import mark
from torch.testing import assert_close

from torchjd.aggregation._dual_cone_utils import _get_projection_weights


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_solution_weights(shape: tuple[int, int]):
r"""
Tests that `_get_projection_weights` returns valid weights corresponding to the projection onto
the dual cone of a matrix with the specified shape.

Validation is performed by verifying that the solution satisfies the `KKT conditions
<https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
quadratic program that projects vectors onto the dual cone of a matrix.
Specifically, the solution should satisfy the equivalent set of conditions described in Lemma 4
of [1].

Let:
- `u` be a vector of weights,
- `G` a positive semi-definite matrix,
- Consider the quadratic problem of minimizing `v^\top G v` subject to `u \preceq v`.

Then `w` is a solution if and only if it satisfies the following three conditions:
1. **Dual feasibility:** `u \preceq w`
2. **Primal feasibility:** `0 \preceq G w`
3. **Complementary slackness:** `u^\top G w = w^\top G w`

Reference:
[1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
"""
    matrix = torch.randn(shape)
    weights = torch.rand(shape[0])

    gramian = matrix @ matrix.T

    projection_weights = _get_projection_weights(gramian, weights, "quadprog")
    dual_gap = projection_weights - weights

    # Dual feasibility
    dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
    assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)

    primal_gap = gramian @ projection_weights

    # Primal feasibility
    primal_gap_positive_part = primal_gap[primal_gap >= 0]
    assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)

    # Complementary slackness
    slackness = dual_gap @ primal_gap
    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
28 changes: 0 additions & 28 deletions tests/unit/aggregation/test_upgrad.py
@@ -1,10 +1,7 @@
 import torch
 from pytest import mark
-from torch.testing import assert_close
 
 from torchjd.aggregation import UPGrad
-from torchjd.aggregation.mean import _MeanWeighting
-from torchjd.aggregation.upgrad import _UPGradWrapper
 
 from ._property_testers import (
     ExpectedStructureProperty,
@@ -18,31 +15,6 @@ class TestUPGrad(ExpectedStructureProperty, NonConflictingProperty, PermutationI
     pass


-@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
-def test_upgrad_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
-    matrix = torch.randn(shape)
-    weights = torch.rand(shape[0])
-
-    gramian = matrix @ matrix.T
-
-    W = _UPGradWrapper(_MeanWeighting(), norm_eps=0.0001, reg_eps=0.0, solver="quadprog")
-
-    lagrange_multiplier = W._compute_lagrangian(matrix, weights)
-
-    positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0]
-    assert_close(
-        positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
-    )
-
-    constraint = gramian @ (torch.diag(weights) + lagrange_multiplier.T)
-
-    positive_constraint = constraint[constraint >= 0]
-    assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)
-
-    slackness = torch.trace(lagrange_multiplier @ constraint)
-    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
-
-
 def test_representations():
     A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog")
     assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"