Skip to content

Commit

Permalink
Merge pull request #569 from malariagen/GH568_add_defaults_to_pca
Browse files Browse the repository at this point in the history
Add defaults for min_minor_ac, max_missing_an to pca(), plot_njt()
  • Loading branch information
leehart authored Aug 1, 2024
2 parents fb787df + a85afea commit 8c9389d
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 9 deletions.
10 changes: 7 additions & 3 deletions malariagen_data/anoph/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ def pca(
sample_indices: Optional[base_params.sample_indices] = None,
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
site_class: Optional[base_params.site_class] = None,
min_minor_ac: Optional[base_params.min_minor_ac] = None,
max_missing_an: Optional[base_params.max_missing_an] = None,
min_minor_ac: Optional[
base_params.min_minor_ac
] = pca_params.min_minor_ac_default,
max_missing_an: Optional[
base_params.max_missing_an
] = pca_params.max_missing_an_default,
cohort_size: Optional[base_params.cohort_size] = None,
min_cohort_size: Optional[base_params.min_cohort_size] = None,
max_cohort_size: Optional[base_params.max_cohort_size] = None,
Expand All @@ -73,7 +77,7 @@ def pca(
) -> Tuple[pca_params.df_pca, pca_params.evr]:
# Change this name if you ever change the behaviour of this function, to
# invalidate any previously cached data.
name = "pca_v2"
name = "pca_v3"

# Normalize params for consistent hash value.
(
Expand Down
5 changes: 5 additions & 0 deletions malariagen_data/anoph/pca_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from typing_extensions import Annotated, TypeAlias
from . import base_params

n_components: TypeAlias = Annotated[
int,
Expand All @@ -23,3 +24,7 @@
np.ndarray,
"An array of explained variance ratios, one per component.",
]

min_minor_ac_default: base_params.min_minor_ac = 2

max_missing_an_default: base_params.max_missing_an = 0
2 changes: 1 addition & 1 deletion malariagen_data/anoph/snp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,7 @@ def biallelic_snp_calls(
ds_out = xr.Dataset(coords=coords, data_vars=data_vars, attrs=ds.attrs)

# Apply conditions.
if max_missing_an or min_minor_ac:
if max_missing_an is not None or min_minor_ac is not None:
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)

# Apply missingness condition.
Expand Down
9 changes: 7 additions & 2 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
hapnet_params,
het_params,
ihs_params,
pca_params,
plotly_params,
xpehh_params,
)
Expand Down Expand Up @@ -3158,8 +3159,12 @@ def plot_njt(
sample_indices: Optional[base_params.sample_indices] = None,
site_mask: Optional[base_params.site_mask] = base_params.DEFAULT,
site_class: Optional[base_params.site_class] = None,
min_minor_ac: Optional[base_params.min_minor_ac] = None,
max_missing_an: Optional[base_params.max_missing_an] = None,
min_minor_ac: Optional[
base_params.min_minor_ac
] = pca_params.min_minor_ac_default,
max_missing_an: Optional[
base_params.max_missing_an
] = pca_params.max_missing_an_default,
cohort_size: Optional[base_params.cohort_size] = None,
min_cohort_size: Optional[base_params.min_cohort_size] = None,
max_cohort_size: Optional[base_params.max_cohort_size] = None,
Expand Down
12 changes: 9 additions & 3 deletions tests/anoph/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from malariagen_data import af1 as _af1
from malariagen_data import ag3 as _ag3
from malariagen_data.anoph.pca import AnophelesPca
from malariagen_data.anoph import pca_params


@pytest.fixture
Expand Down Expand Up @@ -83,13 +84,18 @@ def test_pca_plotting(fixture, api: AnophelesPca):
sample_sets=random.sample(all_sample_sets, 2),
site_mask=random.choice((None,) + api.site_mask_ids),
)
ds = api.biallelic_snp_calls(**data_params)
ds = api.biallelic_snp_calls(
min_minor_ac=pca_params.min_minor_ac_default,
max_missing_an=pca_params.max_missing_an_default,
**data_params,
)

# PCA parameters.
n_samples = ds.sizes["samples"]
n_snps_available = ds.sizes["variants"]
n_snps = random.randint(n_samples, n_snps_available)
n_components = random.randint(3, n_samples)
n_snps = random.randint(1, n_snps_available)
# PC3 required for plot_pca_coords_3d()
n_components = random.randint(3, min(n_samples, n_snps))

# Run the PCA.
pca_df, pca_evr = api.pca(
Expand Down

0 comments on commit 8c9389d

Please sign in to comment.