Raise ValueError when a cohort query selects no samples.

snp_allele_frequencies() now counts the samples matched by each cohort
query and raises ValueError("no samples for cohort ...") when the count
is zero, instead of computing allele counts over an empty selection.
Two tests are added, one for snp_allele_frequencies() and one for
gene_cnv_frequencies(), each using a cohort query (year == 2050) that
is expected to match nothing and asserting that ValueError is raised.

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch. Line counts agree with every
hunk header, but the indentation depth of the context lines (the loop
body in ag3.py and the tail of test_gene_cnv_frequencies) is inferred.
Confirm against the target files, or apply with
`git apply --ignore-whitespace` / `patch -l`.

diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py
index 196b4d3ed..3d97077ed 100644
--- a/malariagen_data/ag3.py
+++ b/malariagen_data/ag3.py
@@ -796,6 +796,9 @@ def snp_allele_frequencies(
         for coh, query in cohorts.items():
             # locate samples
             loc_coh = df_meta.eval(query).values
+            n_samples = np.count_nonzero(loc_coh)
+            if n_samples == 0:
+                raise ValueError(f"no samples for cohort {coh!r}")
             gt_coh = np.compress(loc_coh, gt, axis=1)
             # count alleles
             ac_coh = allel.GenotypeArray(gt_coh).count_alleles(max_allele=3)
diff --git a/tests/test_ag3.py b/tests/test_ag3.py
index 89e0516cf..101d331d5 100644
--- a/tests/test_ag3.py
+++ b/tests/test_ag3.py
@@ -664,6 +664,22 @@ def test_snp_allele_frequencies():
     assert np.any(df.max_af == 0)
 
 
+def test_snp_allele_frequencies_0_cohort():
+    ag3 = setup_ag3()
+    cohorts = {
+        "bf_2050_col": "country == 'Burkina Faso' and year == 2050 and species == 'coluzzii'",
+    }
+
+    with pytest.raises(ValueError):
+        _ = ag3.snp_allele_frequencies(
+            transcript="AGAP009194-RA",
+            cohorts=cohorts,
+            site_mask="gamb_colu",
+            sample_sets="v3_wild",
+            drop_invariant=True,
+        )
+
+
 @pytest.mark.parametrize(
     "sample_sets", ["AG1000G-AO", ("AG1000G-AO", "AG1000G-UG"), "v3_wild"]
 )
@@ -1030,3 +1046,20 @@ def test_gene_cnv_frequencies(contig):
         x = a + d
         assert np.all(x >= 0)
         assert np.all(x <= 1)
+
+
+@pytest.mark.parametrize(
+    "contig",
+    [
+        "X",
+    ],
+)
+def test_gene_cnv_frequencies_0_cohort(contig):
+    ag3 = setup_ag3()
+    cohorts = {
+        "bf_2050_col": "country == 'Burkina Faso' and year == 2050 and species == 'coluzzii'",
+    }
+    with pytest.raises(ValueError):
+        _ = ag3.gene_cnv_frequencies(
+            contig=contig, sample_sets="v3_wild", cohorts=cohorts
+        )