-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(aggregation): Refactor dual cone projections (#237)
* Simplify dual cone projection by directly solving the problem of Proposition 1 of the paper
* Add _project_weights, _project_weight_vector and _to_array in _dual_cone_utils.py to factorize dual cone projections
* Rename norm_eps to eps in _compute_normalized_gramian
* Add _compute_regularized_normalized_gramian and _regularize in _gramian_utils.py to factorize regularization of the gramian
* Improve testing of dual cone projections
* Add changelog entry
--------- Co-authored-by: Valérian Rey <[email protected]>
- Loading branch information
1 parent
6e6f54c
commit 1d87f56
Showing
7 changed files
with
164 additions
and
94 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import Literal | ||
|
||
import numpy as np | ||
import torch | ||
from qpsolvers import solve_qp | ||
from torch import Tensor | ||
|
||
|
||
def _project_weights(U: Tensor, G: Tensor, solver: Literal["quadprog"]) -> Tensor:
    """
    Applies `_project_weight_vector` to every weight vector stored along the last dimension of
    `U`, using the Gramian `G`, and gathers the resulting projection weights into a tensor.

    Both inputs are converted to float64 numpy arrays before the quadratic programs are solved.
    The result is converted back to a tensor placed on the device of `G` and using the dtype of
    `G`.

    :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`.
    :param G: The Gramian matrix of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    :return: A tensor of projection weights with the same shape as `U`.
    """

    gramian_array = _to_array(G)
    weights_array = _to_array(U)

    def project(weight_vector: np.ndarray) -> np.ndarray:
        # One quadratic program is solved per 1-D slice taken along the last axis.
        return _project_weight_vector(weight_vector, gramian_array, solver)

    projected = np.apply_along_axis(project, -1, weights_array)

    return torch.as_tensor(projected, device=G.device, dtype=G.dtype)
|
||
|
||
def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: Literal["quadprog"]) -> np.ndarray:
    r"""
    Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`,
    given `G = J J^T` and `u`. In other words, this computes the `w` that satisfies
    `\pi_J(J^T u) = J^T w`.

    By Proposition 1 of [1], this is equivalent to solving for `v` the following quadratic program:

        minimize        v^T G v
        subject to      u \preceq v

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.

    :param u: The vector of weights `u` of shape `[m]` corresponding to the vector `J^T u` to
        project.
    :param G: The Gramian matrix of `J`, equal to `J J^T`, and of shape `[m, m]`.
    :param solver: The quadratic programming solver to use.
    :return: The vector of projection weights `w`, of shape `[m]`.
    :raises ValueError: If the solver does not return a solution.
    """

    m = G.shape[0]

    # `solve_qp(P, q, G, h)` minimizes `1/2 x^T P x + q^T x` subject to `G x <= h`. Here `P` is the
    # Gramian, `q` is zero, and the constraint `u <= v` is written as `-I v <= -u`.
    w = solve_qp(G, np.zeros(m), -np.eye(m), -u, solver=solver)

    # `solve_qp` returns None when no solution is found. Fail loudly here rather than letting a
    # None propagate into the array handling of the caller.
    if w is None:
        raise ValueError("Failed to solve the quadratic programming problem.")

    return w
|
||
|
||
def _to_array(tensor: Tensor) -> np.ndarray: | ||
"""Transforms a tensor into a numpy array with float64 dtype.""" | ||
|
||
return tensor.cpu().detach().numpy().astype(np.float64) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import torch | ||
from pytest import mark | ||
from torch.testing import assert_close | ||
|
||
from torchjd.aggregation._dual_cone_utils import _project_weights | ||
|
||
|
||
@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_solution_weights(shape: tuple[int, int]):
    r"""
    Checks that the weights returned by `_project_weights` are valid projection weights onto the
    dual cone of a random matrix of the given shape.

    The check relies on the `KKT conditions
    <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ of the
    quadratic program that projects vectors onto the dual cone of a matrix, in the equivalent form
    given in Lemma 4 of [1].

    Let `u` be a vector of weights and `G` a positive semi-definite matrix. Consider the quadratic
    problem of minimizing `v^T G v` subject to `u \preceq v`. Then `w` is a solution if and only
    if it satisfies the following three conditions:

    1. **Dual feasibility:** `u \preceq w`
    2. **Primal feasibility:** `0 \preceq G w`
    3. **Complementary slackness:** `u^T G w = w^T G w`

    Each inequality is verified up to a tolerance by comparing the norm of the non-negative
    entries of the corresponding gap with the norm of the whole gap.

    Reference:
    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
    """

    jacobian = torch.randn(shape)
    gramian = jacobian @ jacobian.T
    weights = torch.rand(shape[0])

    projected_weights = _project_weights(weights, gramian, "quadprog")

    # Dual feasibility: the entries of w - u should all be non-negative.
    dual_gap = projected_weights - weights
    non_negative_dual_gap = dual_gap[dual_gap >= 0.0]
    assert_close(non_negative_dual_gap.norm(), dual_gap.norm(), atol=1e-05, rtol=0)

    # Primal feasibility: the entries of G w should all be non-negative.
    primal_gap = gramian @ projected_weights
    non_negative_primal_gap = primal_gap[primal_gap >= 0]
    assert_close(non_negative_primal_gap.norm(), primal_gap.norm(), atol=1e-04, rtol=0)

    # Complementary slackness: (w - u)^T G w should be zero.
    slackness = dual_gap @ primal_gap
    assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
|
||
|
||
@mark.parametrize("shape", [(5, 2, 3), (1, 3, 6, 9), (2, 1, 1, 5, 8), (3, 1)])
def test_tensorization_shape(shape: tuple[int, ...]):
    """
    Checks that `_project_weights` handles leading dimensions consistently: projecting a tensor
    of weights must give the same result as projecting that tensor flattened into a matrix and
    reshaping the output back to the original shape.
    """

    m = shape[-1]

    # Random draws are made in this order: first the factor of the Gramian, then the weights.
    factor = torch.randn([m, m])
    weights_tensor = torch.randn(shape)
    weights_matrix = weights_tensor.reshape([-1, m])

    gramian = factor @ factor.T

    projected_tensor = _project_weights(weights_tensor, gramian, "quadprog")
    projected_matrix = _project_weights(weights_matrix, gramian, "quadprog")

    assert_close(projected_matrix.reshape(shape), projected_tensor)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters