diff --git a/malariagen_data/anoph/base_params.py b/malariagen_data/anoph/base_params.py index 11c3e6b74..5cefff779 100644 --- a/malariagen_data/anoph/base_params.py +++ b/malariagen_data/anoph/base_params.py @@ -271,19 +271,21 @@ def validate_sample_selection_params( ] min_minor_ac: TypeAlias = Annotated[ - int, + Union[int, float], """ The minimum minor allele count. SNPs with a minor allele count - below this value will be excluded. + below this value will be excluded. Can also be a float, which will + be interpreted as a fraction. """, ] max_missing_an: TypeAlias = Annotated[ - int, + Union[int, float], """ The maximum number of missing allele calls to accept. SNPs with more than this value will be excluded. Set to 0 to require no - missing calls. + missing calls. Can also be a float, which will be interpreted as + a fraction. """, ] diff --git a/malariagen_data/anoph/snp_data.py b/malariagen_data/anoph/snp_data.py index 7db7e91f3..8cb398552 100644 --- a/malariagen_data/anoph/snp_data.py +++ b/malariagen_data/anoph/snp_data.py @@ -1659,18 +1659,26 @@ def biallelic_snp_calls( # Apply conditions. if max_missing_an is not None or min_minor_ac is not None: loc_out = np.ones(ds_out.sizes["variants"], dtype=bool) + an = ac_out.sum(axis=1) # Apply missingness condition. if max_missing_an is not None: - an = ac_out.sum(axis=1) an_missing = (ds_out.sizes["samples"] * ds_out.sizes["ploidy"]) - an - loc_missing = an_missing <= max_missing_an + if isinstance(max_missing_an, float): + an_missing_frac = an_missing / an + loc_missing = an_missing_frac <= max_missing_an + else: + loc_missing = an_missing <= max_missing_an loc_out &= loc_missing # Apply minor allele count condition. if min_minor_ac is not None: ac_minor = ac_out.min(axis=1) - loc_minor = ac_minor >= min_minor_ac + if isinstance(min_minor_ac, float): + ac_minor_frac = ac_minor / an + loc_minor = ac_minor_frac >= min_minor_ac + else: + loc_minor = ac_minor >= min_minor_ac loc_out &= loc_minor ds_out = ds_out.isel(variants=loc_out) diff --git a/tests/anoph/test_snp_data.py b/tests/anoph/test_snp_data.py index 19576a6d3..e31bf35d0 100644 --- a/tests/anoph/test_snp_data.py +++ b/tests/anoph/test_snp_data.py @@ -1387,3 +1387,69 @@ def test_biallelic_snp_calls_and_diplotypes_with_conditions( max_missing_an=max_missing_an, n_snps=n_snps_available + 10, ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_biallelic_snp_calls_and_diplotypes_with_conditions_fractional( + fixture, api: AnophelesSnpData +): + # Fixed parameters. + contig = random.choice(api.contigs) + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice((None,) + api.site_mask_ids) + + # Parametrise conditions. + min_minor_ac = random.uniform(0, 0.05) + max_missing_an = random.uniform(0.05, 0.2) + + # Run tests. + ds = check_biallelic_snp_calls_and_diplotypes( + api=api, + sample_sets=sample_sets, + region=contig, + site_mask=site_mask, + min_minor_ac=min_minor_ac, + max_missing_an=max_missing_an, + ) + + # Check conditions are met. + ac = ds["variant_allele_count"].values + an = ac.sum(axis=1) + ac_min = ac.min(axis=1) + assert np.all((ac_min / an) >= min_minor_ac) + an_missing = (ds.sizes["samples"] * ds.sizes["ploidy"]) - an + assert np.all((an_missing / an) <= max_missing_an) + gt = ds["call_genotype"].values + ac_check = allel.GenotypeArray(gt).count_alleles(max_allele=1) + assert np.all(ac == ac_check) + + # Run tests with thinning. + n_snps_available = ds.sizes["variants"] + # This should always be true, although depends on min_minor_ac and max_missing_an, + # so the range of values for those parameters needs to be chosen with some care. + assert n_snps_available > 2 + n_snps_requested = random.randint(1, n_snps_available // 2) + ds_thinned = check_biallelic_snp_calls_and_diplotypes( + api=api, + sample_sets=sample_sets, + region=contig, + site_mask=site_mask, + min_minor_ac=min_minor_ac, + max_missing_an=max_missing_an, + n_snps=n_snps_requested, + ) + n_snps_thinned = ds_thinned.sizes["variants"] + assert n_snps_thinned >= n_snps_requested + assert n_snps_thinned <= 2 * n_snps_requested + + # Ask for more SNPs than available. + with pytest.raises(ValueError): + api.biallelic_snp_calls( + sample_sets=sample_sets, + region=contig, + site_mask=site_mask, + min_minor_ac=min_minor_ac, + max_missing_an=max_missing_an, + n_snps=n_snps_available + 10, + )