Patch summary (free text ahead of the first file header; patch tools skip it):
  * pca() and plot_njt() now default to min_minor_ac=2 and max_missing_an=0,
    taken from the new constants in pca_params.
  * The PCA cache name is bumped from "pca_v2" to "pca_v3" so results cached
    under the old defaults are not reused.
  * biallelic_snp_calls() tests the filters with "is not None", so an explicit
    value of 0 is applied instead of being treated as "no filter".
  * test_pca_plotting requests the same filters as pca() and draws at least
    3 SNPs, so the n_components range (which starts at 3) is never empty.
NOTE(review): line structure and leading indentation were reconstructed from a
whitespace-collapsed copy of this patch; confirm indentation against the tree.

diff --git a/malariagen_data/anoph/pca.py b/malariagen_data/anoph/pca.py
index 6d819ba4e..a5a28adeb 100644
--- a/malariagen_data/anoph/pca.py
+++ b/malariagen_data/anoph/pca.py
@@ -62,8 +62,12 @@ def pca(
         sample_indices: Optional[base_params.sample_indices] = None,
         site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
         site_class: Optional[base_params.site_class] = None,
-        min_minor_ac: Optional[base_params.min_minor_ac] = None,
-        max_missing_an: Optional[base_params.max_missing_an] = None,
+        min_minor_ac: Optional[
+            base_params.min_minor_ac
+        ] = pca_params.min_minor_ac_default,
+        max_missing_an: Optional[
+            base_params.max_missing_an
+        ] = pca_params.max_missing_an_default,
         cohort_size: Optional[base_params.cohort_size] = None,
         min_cohort_size: Optional[base_params.min_cohort_size] = None,
         max_cohort_size: Optional[base_params.max_cohort_size] = None,
@@ -73,7 +77,7 @@ def pca(
     ) -> Tuple[pca_params.df_pca, pca_params.evr]:
         # Change this name if you ever change the behaviour of this function, to
         # invalidate any previously cached data.
-        name = "pca_v2"
+        name = "pca_v3"
 
         # Normalize params for consistent hash value.
         (
diff --git a/malariagen_data/anoph/pca_params.py b/malariagen_data/anoph/pca_params.py
index e74a243af..ef959509d 100644
--- a/malariagen_data/anoph/pca_params.py
+++ b/malariagen_data/anoph/pca_params.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pandas as pd
 from typing_extensions import Annotated, TypeAlias
+from . import base_params
 
 n_components: TypeAlias = Annotated[
     int,
@@ -23,3 +24,7 @@
     np.ndarray,
     "An array of explained variance ratios, one per component.",
 ]
+
+min_minor_ac_default: base_params.min_minor_ac = 2
+
+max_missing_an_default: base_params.max_missing_an = 0
diff --git a/malariagen_data/anoph/snp_data.py b/malariagen_data/anoph/snp_data.py
index 8e16c3552..7b404f3a8 100644
--- a/malariagen_data/anoph/snp_data.py
+++ b/malariagen_data/anoph/snp_data.py
@@ -1655,7 +1655,7 @@ def biallelic_snp_calls(
             ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)
 
             # Apply conditions.
-            if max_missing_an or min_minor_ac:
+            if max_missing_an is not None or min_minor_ac is not None:
                 loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
 
                 # Apply missingness condition.
diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py
index 615215858..324cc395e 100644
--- a/malariagen_data/anopheles.py
+++ b/malariagen_data/anopheles.py
@@ -35,6 +35,7 @@
     hapnet_params,
     het_params,
     ihs_params,
+    pca_params,
     plotly_params,
     xpehh_params,
 )
@@ -3158,8 +3159,12 @@ def plot_njt(
         sample_indices: Optional[base_params.sample_indices] = None,
         site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
         site_class: Optional[base_params.site_class] = None,
-        min_minor_ac: Optional[base_params.min_minor_ac] = None,
-        max_missing_an: Optional[base_params.max_missing_an] = None,
+        min_minor_ac: Optional[
+            base_params.min_minor_ac
+        ] = pca_params.min_minor_ac_default,
+        max_missing_an: Optional[
+            base_params.max_missing_an
+        ] = pca_params.max_missing_an_default,
         cohort_size: Optional[base_params.cohort_size] = None,
         min_cohort_size: Optional[base_params.min_cohort_size] = None,
         max_cohort_size: Optional[base_params.max_cohort_size] = None,
diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py
index 554644f98..8d452fcce 100644
--- a/tests/anoph/test_pca.py
+++ b/tests/anoph/test_pca.py
@@ -9,6 +9,7 @@
 from malariagen_data import af1 as _af1
 from malariagen_data import ag3 as _ag3
 from malariagen_data.anoph.pca import AnophelesPca
+from malariagen_data.anoph import pca_params
 
 
 @pytest.fixture
@@ -83,13 +84,19 @@ def test_pca_plotting(fixture, api: AnophelesPca):
         sample_sets=random.sample(all_sample_sets, 2),
         site_mask=random.choice((None,) + api.site_mask_ids),
     )
-    ds = api.biallelic_snp_calls(**data_params)
+    ds = api.biallelic_snp_calls(
+        min_minor_ac=pca_params.min_minor_ac_default,
+        max_missing_an=pca_params.max_missing_an_default,
+        **data_params,
+    )
 
     # PCA parameters.
     n_samples = ds.sizes["samples"]
     n_snps_available = ds.sizes["variants"]
-    n_snps = random.randint(n_samples, n_snps_available)
-    n_components = random.randint(3, n_samples)
+    # Draw at least 3 SNPs so the n_components range below is never empty.
+    n_snps = random.randint(3, n_snps_available)
+    # PC3 required for plot_pca_coords_3d()
+    n_components = random.randint(3, min(n_samples, n_snps))
 
     # Run the PCA.
     pca_df, pca_evr = api.pca(