Skip to content

Commit

Permalink
support fractions for max_missing_an and min_minor_ac parameters (#617)
Browse files Browse the repository at this point in the history
  • Loading branch information
alimanfoo authored Sep 20, 2024
1 parent ce614ae commit 5ec5de3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 7 deletions.
10 changes: 6 additions & 4 deletions malariagen_data/anoph/base_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,21 @@ def validate_sample_selection_params(
]

min_minor_ac: TypeAlias = Annotated[
int,
Union[int, float],
"""
The minimum minor allele count. SNPs with a minor allele count
below this value will be excluded.
below this value will be excluded. Can also be a float, which will
be interpreted as a fraction.
""",
]

max_missing_an: TypeAlias = Annotated[
int,
Union[int, float],
"""
The maximum number of missing allele calls to accept. SNPs with
more than this value will be excluded. Set to 0 to require no
missing calls.
missing calls. Can also be a float, which will be interpreted as
a fraction.
""",
]

Expand Down
14 changes: 11 additions & 3 deletions malariagen_data/anoph/snp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,18 +1659,26 @@ def biallelic_snp_calls(
# Apply conditions.
if max_missing_an is not None or min_minor_ac is not None:
loc_out = np.ones(ds_out.sizes["variants"], dtype=bool)
an = ac_out.sum(axis=1)

# Apply missingness condition.
if max_missing_an is not None:
an = ac_out.sum(axis=1)
an_missing = (ds_out.sizes["samples"] * ds_out.sizes["ploidy"]) - an
loc_missing = an_missing <= max_missing_an
if isinstance(max_missing_an, float):
an_missing_frac = an_missing / an
loc_missing = an_missing_frac <= max_missing_an
else:
loc_missing = an_missing <= max_missing_an
loc_out &= loc_missing

# Apply minor allele count condition.
if min_minor_ac is not None:
ac_minor = ac_out.min(axis=1)
loc_minor = ac_minor >= min_minor_ac
if isinstance(min_minor_ac, float):
ac_minor_frac = ac_minor / an
loc_minor = ac_minor_frac >= min_minor_ac
else:
loc_minor = ac_minor >= min_minor_ac
loc_out &= loc_minor

ds_out = ds_out.isel(variants=loc_out)
Expand Down
66 changes: 66 additions & 0 deletions tests/anoph/test_snp_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1387,3 +1387,69 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions(
max_missing_an=max_missing_an,
n_snps=n_snps_available + 10,
)


@parametrize_with_cases("fixture,api", cases=".")
def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional(
    fixture, api: AnophelesSnpData
):
    """Check biallelic SNP calls when both filter thresholds are given as floats.

    A random contig, sample set and site mask are selected, then random float
    values are drawn for ``min_minor_ac`` and ``max_missing_an``. The test
    verifies that every returned variant satisfies both thresholds when they
    are expressed relative to the number of called alleles, that the stored
    allele counts agree with a recount from the genotypes, that thinning with
    ``n_snps`` yields between one and two times the requested number of SNPs,
    and that asking for more SNPs than are available raises ``ValueError``.
    """
    # Randomly choose what to query.
    contig = random.choice(api.contigs)
    sample_sets = random.choice(api.sample_sets()["sample_set"].to_list())
    site_mask = random.choice((None,) + api.site_mask_ids)

    # Draw float thresholds; the draw order is kept stable (minor first, then
    # missing) so a seeded random state reproduces the same values.
    min_minor_frac = random.uniform(0, 0.05)
    max_missing_frac = random.uniform(0.05, 0.2)

    # Arguments shared by every call made below.
    query = dict(
        sample_sets=sample_sets,
        region=contig,
        site_mask=site_mask,
        min_minor_ac=min_minor_frac,
        max_missing_an=max_missing_frac,
    )

    # Unthinned query, run through the shared checker.
    ds = check_biallelic_snp_calls_and_diplotypes(api=api, **query)

    # Every variant must satisfy the minor allele threshold...
    allele_counts = ds["variant_allele_count"].values
    n_called = allele_counts.sum(axis=1)
    n_minor = allele_counts.min(axis=1)
    assert np.all((n_minor / n_called) >= min_minor_frac)

    # ...and the missingness threshold.
    n_missing = (ds.sizes["samples"] * ds.sizes["ploidy"]) - n_called
    assert np.all((n_missing / n_called) <= max_missing_frac)

    # Stored allele counts must match a recount from the genotype calls.
    genotypes = ds["call_genotype"].values
    recounted = allel.GenotypeArray(genotypes).count_alleles(max_allele=1)
    assert np.all(allele_counts == recounted)

    # Thinning. Enough SNPs are expected to survive the filters, but that
    # relies on the threshold ranges above having been chosen with some care.
    n_snps_available = ds.sizes["variants"]
    assert n_snps_available > 2
    n_snps_requested = random.randint(1, n_snps_available // 2)
    ds_thinned = check_biallelic_snp_calls_and_diplotypes(
        api=api, **query, n_snps=n_snps_requested
    )
    n_snps_thinned = ds_thinned.sizes["variants"]
    assert n_snps_requested <= n_snps_thinned <= 2 * n_snps_requested

    # Requesting more SNPs than exist must be rejected.
    with pytest.raises(ValueError):
        api.biallelic_snp_calls(**query, n_snps=n_snps_available + 10)

0 comments on commit 5ec5de3

Please sign in to comment.