From b01423344403203a76cb272df16d2ec6055fd84a Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Tue, 10 Dec 2024 15:44:11 +0100 Subject: [PATCH] adapt to new epsilon scheduler class --- tests/problems/cross_modality/test_translation_problem.py | 3 +-- tests/problems/generic/test_fgw_problem.py | 3 +-- tests/problems/generic/test_gw_problem.py | 3 +-- tests/problems/generic/test_sinkhorn_problem.py | 3 +-- tests/problems/space/test_alignment_problem.py | 3 +-- tests/problems/space/test_mapping_problem.py | 3 +-- tests/problems/spatio_temporal/test_spatio_temporal_problem.py | 3 +-- tests/problems/time/test_temporal_problem.py | 3 +-- 8 files changed, 8 insertions(+), 16 deletions(-) diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index 5d444db30..81f2232d3 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -173,8 +173,7 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 4595cbbab..df1bb6347 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -139,8 +139,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 064923cf3..f92b0a67d 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -144,8 +144,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 1badbf49b..9892ee781 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -177,8 +177,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 47202ffe0..901f46d7e 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -224,8 +224,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index 4ac6266ed..7d6520293 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -262,8 +262,7 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index ac661391d..3632d76b7 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -224,8 +224,7 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index f53a745c9..d3780a8ec 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -456,8 +456,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg]