Skip to content

Commit

Permalink
Merge pull request #36 from OxfordRSE/re-enable-l2g-tests
Browse files Browse the repository at this point in the history
Re-enable local2global tests
  • Loading branch information
mihaeladuta authored Dec 11, 2024
2 parents 9aaa273 + c446ab1 commit 355842b
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 42 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@ jobs:
python3 -m pip install '.[dev]' --find-links https://data.pyg.org/whl/torch-2.4.1%2Bcpu.html
- name: Test with pytest
run: |
python3 -m pytest --ignore tests/test_local2global.py
python3 -m pytest -n auto
11 changes: 11 additions & 0 deletions l2gv2/patch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@
from .patch import Patch


def seed(new_seed):
"""
Change seed of random number generator.
Args:
new_seed: New seed value
"""
np.random.default_rng(new_seed)


def random_gen(new_seed=None) -> np.random.Generator:
"""Change seed of random number generator.
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dependencies = [
dev = [
"pytest >= 8",
"pytest-cov",
"pytest-xdist",
"pylint",
"pre_commit == 4.0.1"
]
Expand Down Expand Up @@ -90,3 +91,6 @@ reportOptionalMemberAccess = false
reportOptionalSubscript = false
reportGeneralTypeIssues = false
reportMissingTypeStubs = false

[tool.coverage.run]
omit = ["tests/*"]
46 changes: 5 additions & 41 deletions tests/test_local2global.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# pylint: disable=too-many-positional-arguments
"""Test local2global reconstruction"""

import os
import sys
from copy import copy
from statistics import mean
Expand Down Expand Up @@ -80,16 +79,13 @@ def iter_seed(it: int) -> None:
it (int): Iteration number used to generate a unique seed.
Returns:
None
"""
it_seed = np.random.SeedSequence(entropy=_seed.entropy, n_children_spawned=it)
ut.seed(it_seed.spawn(1)[0])


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.xfail(reason="Noisy tests may fail, though many failures are a bad sign")
@pytest.mark.parametrize("it", range(100))
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("patches,min_overlap", zip(patches_list, MIN_OVERLAP))
Expand All @@ -114,15 +110,11 @@ def test_stability(it, problem_cls, patches, min_overlap):
# ex.add_noise(problem, 1e-8)
rotations = ex.rand_rotate_patches(problem)
recovered_rots = problem.calc_synchronised_rotations()
error = ut.orthogonal_MSE_error(rotations, recovered_rots)
error = ut.orthogonal_mse_error(rotations, recovered_rots)
print(f"Mean error is {error}")
assert error < TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("patches,min_overlap", zip(patches_list, MIN_OVERLAP))
def test_calc_synchronised_rotations(problem_cls, patches, min_overlap):
Expand All @@ -142,15 +134,11 @@ def test_calc_synchronised_rotations(problem_cls, patches, min_overlap):
rotations = ex.rand_rotate_patches(problem)
ex.rand_shift_patches(problem)
recovered_rots = problem.calc_synchronised_rotations()
error = ut.orthogonal_MSE_error(rotations, recovered_rots)
error = ut.orthogonal_mse_error(rotations, recovered_rots)
print(f"Mean error is {error}")
assert error < TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.xfail(reason="Noisy tests may fail, though many failures are a bad sign")
@pytest.mark.parametrize("test_class", test_classes)
@pytest.mark.parametrize("noise", NOISE_SCALES)
Expand Down Expand Up @@ -188,18 +176,14 @@ def test_noisy_calc_synchronised_rotations(noise, test_class, patches, min_overl

recovered_rots = problem.calc_synchronised_rotations()
problem.rotate_patches(rotations=[r.T for r in recovered_rots])
error = ut.orthogonal_MSE_error(rotations, recovered_rots)
error = ut.orthogonal_mse_error(rotations, recovered_rots)
print(f"Mean rotation error is {error}")
print(
f"Error of relative rotations is min: {min_err}, mean: {mean_err}, max: {max_err}"
)
assert error < max(max_err, TOL)


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("patches,min_overlap", zip(patches_list, MIN_OVERLAP))
def test_calc_synchronised_scales(problem_cls, patches, min_overlap):
Expand All @@ -225,10 +209,6 @@ def test_calc_synchronised_scales(problem_cls, patches, min_overlap):
assert error < TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.xfail(reason="Noisy tests may fail, though many failures are a bad sign")
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("noise", NOISE_SCALES)
Expand Down Expand Up @@ -277,10 +257,6 @@ def test_noisy_calc_synchronised_scales(problem_cls, noise, patches, min_overlap
assert error < max_err + TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("patches,min_overlap", zip(patches_list, MIN_OVERLAP))
def test_calc_synchronised_translations(problem_cls, patches, min_overlap):
Expand All @@ -304,10 +280,6 @@ def test_calc_synchronised_translations(problem_cls, patches, min_overlap):
assert error < TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.xfail(reason="Noisy tests may fail, though many failures are a bad sign")
@pytest.mark.parametrize("noise", NOISE_SCALES)
@pytest.mark.parametrize("patches,min_overlap", zip(patches_list, MIN_OVERLAP))
Expand All @@ -333,10 +305,6 @@ def test_noisy_calc_synchronised_translations(noise, patches, min_overlap):
assert error < noise + TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize(
"patches,min_overlap,points", zip(patches_list, MIN_OVERLAP, points_list)
Expand Down Expand Up @@ -366,10 +334,6 @@ def test_get_aligned_embedding(problem_cls, patches, min_overlap, points):
assert error < TOL


@pytest.mark.skipif(
os.getenv("GITHUB_ACTIONS") == "true",
reason="local2global tests disabled in GitHub Actions",
)
@pytest.mark.xfail(reason="Noisy tests may fail, though many failures are a bad sign")
@pytest.mark.parametrize("problem_cls", test_classes)
@pytest.mark.parametrize("it", range(3))
Expand Down

0 comments on commit 355842b

Please sign in to comment.