From 86e7e023f6c3adb0551ed6e4c2aa443e6472d89d Mon Sep 17 00:00:00 2001 From: jonbrenas <51911846+jonbrenas@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:40:09 +0000 Subject: [PATCH] Extended the tests to CNVs and HAPs --- tests/anoph/test_cnv_frq.py | 42 ++++++++++++++++++++++++++++++------- tests/anoph/test_hap_frq.py | 10 +++++++++ 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/tests/anoph/test_cnv_frq.py b/tests/anoph/test_cnv_frq.py index ef96a400..a96cc0bf 100644 --- a/tests/anoph/test_cnv_frq.py +++ b/tests/anoph/test_cnv_frq.py @@ -11,6 +11,13 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.cnv_frq import AnophelesCnvFrequencyAnalysis from malariagen_data.util import compare_series_like +from .test_frq import ( + test_plot_frequencies_heatmap, + test_plot_frequencies_time_series, + test_plot_frequencies_time_series_with_taxa, + test_plot_frequencies_time_series_with_areas, + test_plot_frequencies_interactive_map, +) @pytest.fixture @@ -109,6 +116,8 @@ def test_gene_cnv_frequencies_with_str_cohorts( # Run the function under test. df_cnv = api.gene_cnv_frequencies(**params) + test_plot_frequencies_heatmap(api, df_cnv) + # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) if "cohort_" + cohorts in df_samples: @@ -166,12 +175,14 @@ def test_gene_cnv_frequencies_with_min_cohort_size( return # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -212,12 +223,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -268,12 +281,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -305,12 +320,14 @@ def test_gene_cnv_frequencies_with_dict_cohorts( ) # Run the function under test. - df_snp = api.gene_cnv_frequencies(**params) + df_cnv = api.gene_cnv_frequencies(**params) + + test_plot_frequencies_heatmap(api, df_cnv) # Standard checks. check_gene_cnv_frequencies( api=api, - df=df_snp, + df=df_cnv, cohort_labels=cohort_labels, region=region, ) @@ -350,6 +367,9 @@ def test_gene_cnv_frequencies_without_drop_invariant( df_cnv_a = api.gene_cnv_frequencies(drop_invariant=True, **params) df_cnv_b = api.gene_cnv_frequencies(drop_invariant=False, **params) + test_plot_frequencies_heatmap(api, df_cnv_a) + test_plot_frequencies_heatmap(api, df_cnv_b) + # Standard checks. check_gene_cnv_frequencies( api=api, @@ -418,6 +438,8 @@ def test_gene_cnv_frequencies_with_max_coverage_variance( # checks. df_cnv = api.gene_cnv_frequencies(**params) + test_plot_frequencies_heatmap(api, df_cnv) + # Figure out expected cohort labels. df_samples = api.sample_metadata(sample_sets=sample_sets) if "cohort_" + cohorts in df_samples: @@ -711,6 +733,10 @@ def check_gene_cnv_frequencies_advanced( # Check the result. assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} # Check variant variables. diff --git a/tests/anoph/test_hap_frq.py b/tests/anoph/test_hap_frq.py index 7dc54346..b8278aa9 100644 --- a/tests/anoph/test_hap_frq.py +++ b/tests/anoph/test_hap_frq.py @@ -8,6 +8,12 @@ from malariagen_data import ag3 as _ag3 from malariagen_data.anoph.hap_frq import AnophelesHapFrequencyAnalysis +from .test_frq import ( + test_plot_frequencies_time_series, + test_plot_frequencies_time_series_with_taxa, + test_plot_frequencies_time_series_with_areas, + test_plot_frequencies_interactive_map, +) @pytest.fixture @@ -82,6 +88,10 @@ def check_hap_frequencies_advanced( ds, ): assert isinstance(ds, xr.Dataset) + test_plot_frequencies_time_series(api, ds) + test_plot_frequencies_time_series_with_taxa(api, ds) + test_plot_frequencies_time_series_with_areas(api, ds) + test_plot_frequencies_interactive_map(api, ds) assert set(ds.dims) == {"cohorts", "variants"} expected_cohort_vars = [