Improve docstring and naming of the test of KKT conditions.
PierreQuinton committed Feb 2, 2025
1 parent c9afb41 commit e141600
Showing 1 changed file with 34 additions and 10 deletions:

tests/unit/aggregation/test_dual_cone_utils.py
@@ -6,24 +6,48 @@


 @mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
-def test_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):
+def test_solution_weights(shape: tuple[int, int]):
+    r"""
+    Tests that `_get_projection_weights` returns valid weights corresponding to the projection onto
+    the dual cone of a matrix with the specified shape.
+
+    Validation is performed by verifying that the solution satisfies the `KKT conditions
+    <https://en.wikipedia.org/wiki/Karush%E2%80%93Kuhn%E2%80%93Tucker_conditions>`_ for the
+    quadratic program that projects vectors onto the dual cone of a matrix.
+    Specifically, the solution should satisfy the equivalent set of conditions described in Lemma 4
+    of [1].
+
+    Let:
+    - `u` be a vector of weights,
+    - `G` a positive semi-definite matrix,
+    - Consider the quadratic problem of minimizing `v^\top G v` subject to `u \preceq v`.
+
+    Then `w` is a solution if and only if it satisfies the following three conditions:
+    1. **Dual feasibility:** `u \preceq w`
+    2. **Primal feasibility:** `0 \preceq G w`
+    3. **Complementary slackness:** `u^\top G w = w^\top G w`
+
+    Reference:
+    [1] `Jacobian Descent For Multi-Objective Optimization <https://arxiv.org/pdf/2406.16232>`_.
+    """
     matrix = torch.randn(shape)
     weights = torch.rand(shape[0])

     gramian = matrix @ matrix.T

     projection_weights = _get_projection_weights(gramian, weights, "quadprog")
-    lagrange_multiplier = projection_weights - weights
+    dual_gap = projection_weights - weights

-    positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0.0]
-    assert_close(
-        positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
-    )
+    # Dual feasibility
+    dual_gap_positive_part = dual_gap[dual_gap >= 0.0]
+    assert_close(dual_gap_positive_part.norm(), dual_gap.norm(), atol=1e-05, rtol=0)

-    constraint = gramian @ projection_weights
+    primal_gap = gramian @ projection_weights

-    positive_constraint = constraint[constraint >= 0]
-    assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)
+    # Primal feasibility
+    primal_gap_positive_part = primal_gap[primal_gap >= 0]
+    assert_close(primal_gap_positive_part.norm(), primal_gap.norm(), atol=1e-04, rtol=0)

-    slackness = lagrange_multiplier @ constraint
+    # Complementary slackness
+    slackness = dual_gap @ primal_gap
     assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)
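For context, the three conditions listed in the new docstring can be recovered from the textbook KKT conditions of the projection program. The derivation below is a standard computation added here for the reader, not part of the commit or of [1]; `G`, `u`, `v`, `w` are the symbols of the docstring and `\lambda` is the multiplier of the constraint `u \preceq v`.

\begin{aligned}
&\min_{v}\; v^\top G v \quad \text{subject to } u \preceq v,
\qquad L(v, \lambda) = v^\top G v + \lambda^\top (u - v), \quad \lambda \succeq 0,\\
&\text{stationarity: } 2 G w = \lambda
\;\Longrightarrow\; \big(\lambda \succeq 0 \iff 0 \preceq G w\big),
\qquad
\text{complementary slackness: } \lambda^\top (u - w) = 0
\;\Longrightarrow\; u^\top G w = w^\top G w.
\end{aligned}

Together with feasibility of `w` (that is, `u \preceq w`), these are exactly the three properties that the rewritten test asserts.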
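The following standalone sketch mirrors what the test checks, but against a generic solver instead of `_get_projection_weights`. It is illustrative only and not part of the repository: `solve_projection_qp` and `check_lemma4_conditions` are hypothetical names, SciPy's SLSQP stands in for the `quadprog` backend used by the test, and the seed and tolerances are arbitrary.

import numpy as np
from scipy.optimize import minimize


def solve_projection_qp(gramian: np.ndarray, u: np.ndarray) -> np.ndarray:
    """Minimize v^T G v subject to u <= v (element-wise), starting from the feasible point v = u."""

    def objective(v: np.ndarray) -> float:
        return float(v @ gramian @ v)

    def gradient(v: np.ndarray) -> np.ndarray:
        return 2.0 * gramian @ v

    # SLSQP interprets an "ineq" constraint as fun(v) >= 0, i.e. v >= u element-wise.
    constraints = [{"type": "ineq", "fun": lambda v: v - u}]
    result = minimize(objective, x0=u, jac=gradient, method="SLSQP", constraints=constraints)
    return result.x


def check_lemma4_conditions(gramian: np.ndarray, u: np.ndarray, w: np.ndarray, tol: float = 1e-4) -> bool:
    """Check, up to a tolerance, the three conditions that the test asserts."""
    dual_feasibility = bool(np.all(w - u >= -tol))          # u <= w
    primal_feasibility = bool(np.all(gramian @ w >= -tol))  # 0 <= G w
    complementary_slackness = abs(u @ gramian @ w - w @ gramian @ w) <= tol  # u^T G w = w^T G w
    return dual_feasibility and primal_feasibility and complementary_slackness


rng = np.random.default_rng(0)
matrix = rng.standard_normal((5, 7))
gramian = matrix @ matrix.T  # positive semi-definite by construction
u = rng.random(5)
w = solve_projection_qp(gramian, u)
print(check_lemma4_conditions(gramian, u, w))  # expected: True

The test itself avoids a second solver: it checks the three conditions directly on the output of `_get_projection_weights`, and it expresses the two feasibility checks by asserting that the norm of the elementwise positive part of a vector matches the norm of the whole vector, an indirect way of saying that its negative entries are negligible.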
