Skip to content

Commit

Permalink
Improve code coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
PierreQuinton committed Jan 21, 2025
1 parent b7ac60e commit 5c15741
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion tests/unit/aggregation/test_dual_cone_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from pytest import mark
from pytest import mark, raises
from torch.testing import assert_close

from torchjd.aggregation._dual_cone_utils import _weights_of_projection_onto_dual_cone
Expand Down Expand Up @@ -27,3 +27,34 @@ def test_lagrangian_satisfies_kkt_conditions(shape: tuple[int, int]):

slackness = lagrange_multiplier @ constraint
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


@mark.parametrize("shape", [(5, 7), (9, 37), (2, 14), (32, 114), (50, 100)])
def test_lagrangian_satisfies_kkt_conditions_matrix_weights(shape: tuple[int, int]):
matrix = torch.randn(shape)
weights_matrix = torch.diag(torch.rand(shape[0]))

gramian = matrix @ matrix.T

projection_weights = _weights_of_projection_onto_dual_cone(gramian, weights_matrix, "quadprog")
lagrange_multiplier = projection_weights - weights_matrix

positive_lagrange_multiplier = lagrange_multiplier[lagrange_multiplier >= 0.0]
assert_close(
positive_lagrange_multiplier.norm(), lagrange_multiplier.norm(), atol=1e-05, rtol=0
)

constraint = gramian @ projection_weights

positive_constraint = constraint[constraint >= 0]
assert_close(positive_constraint.norm(), constraint.norm(), atol=1e-04, rtol=0)

slackness = torch.trace(constraint @ lagrange_multiplier.T)
assert_close(slackness, torch.zeros_like(slackness), atol=3e-03, rtol=0)


def test_weights_of_projection_onto_dual_cone_invalid_shape():
with raises(ValueError):
_weights_of_projection_onto_dual_cone(
torch.zeros([5, 5]), torch.zeros([5, 2, 3]), "quadprog"
)

0 comments on commit 5c15741

Please sign in to comment.