Skip to content

Commit

Permalink
Extended the tests to CNVs and HAPs
Browse files Browse the repository at this point in the history
  • Loading branch information
jonbrenas committed Dec 12, 2024
1 parent 48d0274 commit 86e7e02
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
42 changes: 34 additions & 8 deletions tests/anoph/test_cnv_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.cnv_frq import AnophelesCnvFrequencyAnalysis
from malariagen_data.util import compare_series_like
from .test_frq import (
test_plot_frequencies_heatmap,
test_plot_frequencies_time_series,
test_plot_frequencies_time_series_with_taxa,
test_plot_frequencies_time_series_with_areas,
test_plot_frequencies_interactive_map,
)


@pytest.fixture
Expand Down Expand Up @@ -109,6 +116,8 @@ def test_gene_cnv_frequencies_with_str_cohorts(
# Run the function under test.
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Figure out expected cohort labels.
df_samples = api.sample_metadata(sample_sets=sample_sets)
if "cohort_" + cohorts in df_samples:
Expand Down Expand Up @@ -166,12 +175,14 @@ def test_gene_cnv_frequencies_with_min_cohort_size(
return

# Run the function under test.
df_snp = api.gene_cnv_frequencies(**params)
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Standard checks.
check_gene_cnv_frequencies(
api=api,
df=df_snp,
df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
Expand Down Expand Up @@ -212,12 +223,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query(
)

# Run the function under test.
df_snp = api.gene_cnv_frequencies(**params)
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Standard checks.
check_gene_cnv_frequencies(
api=api,
df=df_snp,
df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
Expand Down Expand Up @@ -268,12 +281,14 @@ def test_gene_cnv_frequencies_with_str_cohorts_and_sample_query_options(
)

# Run the function under test.
df_snp = api.gene_cnv_frequencies(**params)
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Standard checks.
check_gene_cnv_frequencies(
api=api,
df=df_snp,
df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
Expand Down Expand Up @@ -305,12 +320,14 @@ def test_gene_cnv_frequencies_with_dict_cohorts(
)

# Run the function under test.
df_snp = api.gene_cnv_frequencies(**params)
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Standard checks.
check_gene_cnv_frequencies(
api=api,
df=df_snp,
df=df_cnv,
cohort_labels=cohort_labels,
region=region,
)
Expand Down Expand Up @@ -350,6 +367,9 @@ def test_gene_cnv_frequencies_without_drop_invariant(
df_cnv_a = api.gene_cnv_frequencies(drop_invariant=True, **params)
df_cnv_b = api.gene_cnv_frequencies(drop_invariant=False, **params)

test_plot_frequencies_heatmap(api, df_cnv_a)
test_plot_frequencies_heatmap(api, df_cnv_b)

# Standard checks.
check_gene_cnv_frequencies(
api=api,
Expand Down Expand Up @@ -418,6 +438,8 @@ def test_gene_cnv_frequencies_with_max_coverage_variance(
# checks.
df_cnv = api.gene_cnv_frequencies(**params)

test_plot_frequencies_heatmap(api, df_cnv)

# Figure out expected cohort labels.
df_samples = api.sample_metadata(sample_sets=sample_sets)
if "cohort_" + cohorts in df_samples:
Expand Down Expand Up @@ -711,6 +733,10 @@ def check_gene_cnv_frequencies_advanced(

# Check the result.
assert isinstance(ds, xr.Dataset)
test_plot_frequencies_time_series(api, ds)
test_plot_frequencies_time_series_with_taxa(api, ds)
test_plot_frequencies_time_series_with_areas(api, ds)
test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}

# Check variant variables.
Expand Down
10 changes: 10 additions & 0 deletions tests/anoph/test_hap_frq.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.hap_frq import AnophelesHapFrequencyAnalysis
from .test_frq import (
test_plot_frequencies_time_series,
test_plot_frequencies_time_series_with_taxa,
test_plot_frequencies_time_series_with_areas,
test_plot_frequencies_interactive_map,
)


@pytest.fixture
Expand Down Expand Up @@ -82,6 +88,10 @@ def check_hap_frequencies_advanced(
ds,
):
assert isinstance(ds, xr.Dataset)
test_plot_frequencies_time_series(api, ds)
test_plot_frequencies_time_series_with_taxa(api, ds)
test_plot_frequencies_time_series_with_areas(api, ds)
test_plot_frequencies_interactive_map(api, ds)
assert set(ds.dims) == {"cohorts", "variants"}

expected_cohort_vars = [
Expand Down

0 comments on commit 86e7e02

Please sign in to comment.