Skip to content

Commit

Permalink
Rename populations to cohorts (#47)
Browse files Browse the repository at this point in the history
* rename populations to cohorts

* update poetry to take advantage of scikit-allel wheels
  • Loading branch information
alimanfoo authored May 14, 2021
1 parent e450dff commit 8c75c9e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 93 deletions.
58 changes: 29 additions & 29 deletions malariagen_data/ag3.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,21 +711,21 @@ def snp_effects(self, transcript, site_mask=None, site_filters="dt_20200416"):
def snp_allele_frequencies(
self,
transcript,
populations,
cohorts,
site_mask=None,
site_filters="dt_20200416",
species_calls=("20200422", "aim"),
sample_sets="v3_wild",
drop_invariant=True,
):
"""Compute per variant population allele frequencies for a gene transcript.
"""Compute per variant allele frequencies for a gene transcript.
Parameters
----------
transcript : str
Gene transcript ID (AgamP4.12), e.g., "AGAP004707-RA".
populations : dict
Dictionary to map population IDs to sample queries, e.g.,
cohorts : dict
Dictionary to map cohort IDs to sample queries, e.g.,
{"bf_2012_col": "country == 'Burkina Faso' and year == 2012 and species == 'coluzzii'"}
site_mask : {"gamb_colu_arab", "gamb_colu", "arab"}
Site filters mask to apply.
Expand All @@ -738,7 +738,7 @@ def snp_allele_frequencies(
identifiers (e.g., ["AG1000G-BF-A", "AG1000G-BF-B"]) or a release identifier (e.g.,
"v3") or a list of release identifiers.
drop_invariant : bool, optional
If True, variants with no alternate allele calls in any populations are dropped from
If True, variants with no alternate allele calls in any cohorts are dropped from
the result.
Returns
Expand Down Expand Up @@ -792,13 +792,13 @@ def snp_allele_frequencies(

# count alleles
afs = dict()
for pop, query in populations.items():
loc_pop = df_meta.eval(query).values
gt_pop = da.compress(loc_pop, gt, axis=1)
ac_pop = (
allel.GenotypeDaskArray(gt_pop).count_alleles(max_allele=3).compute()
for coh, query in cohorts.items():
loc_coh = df_meta.eval(query).values
gt_coh = da.compress(loc_coh, gt, axis=1)
ac_coh = (
allel.GenotypeDaskArray(gt_coh).count_alleles(max_allele=3).compute()
)
afs[pop] = ac_pop.to_frequencies()
afs[coh] = ac_coh.to_frequencies()

# set up columns
cols = {
Expand All @@ -807,14 +807,14 @@ def snp_allele_frequencies(
"alt_allele": alt.astype("U1").flatten(),
}

for pop in populations:
cols[pop] = afs[pop][:, 1:].flatten()
for coh in cohorts:
cols[coh] = afs[coh][:, 1:].flatten()

# build df
df = pandas.DataFrame(cols)

# add max allele freq column
df["max_af"] = df[populations].max(axis=1)
df["max_af"] = df[cohorts].max(axis=1)

# drop invariants
if drop_invariant:
Expand Down Expand Up @@ -1602,16 +1602,16 @@ def gene_cnv(self, contig, sample_sets="v3_wild"):

return ds_out

def gene_cnv_frequencies(self, contig, populations, sample_sets="v3_wild"):
def gene_cnv_frequencies(self, contig, cohorts=None, sample_sets="v3_wild"):
"""Compute modal copy number by gene, then compute the frequency of
amplifications and deletions by population, from HMM data.
amplifications and deletions in one or more cohorts, from HMM data.
Parameters
----------
contig : str
Chromosome arm, e.g., "3R".
populations : dict
Dictionary to map population IDs to sample queries, e.g.,
cohorts : dict
Dictionary to map cohort IDs to sample queries, e.g.,
{"bf_2012_col": "country == 'Burkina Faso' and year == 2012 and species == 'coluzzii'"}
sample_sets : str or list of str
Can be a sample set identifier (e.g., "AG1000G-AO") or a list of sample set
Expand Down Expand Up @@ -1650,20 +1650,20 @@ def gene_cnv_frequencies(self, contig, populations, sample_sets="v3_wild"):
is_amp = cn > expected_cn
is_del = (0 <= cn) & (cn < expected_cn)

# compute population frequencies
for pop, query in populations.items():
# compute cohort frequencies
for coh, query in cohorts.items():
loc_samples = df_samples.eval(query).values
n_samples = np.count_nonzero(loc_samples)
if n_samples == 0:
raise ValueError(f"no samples for population {pop!r}")
is_amp_pop = np.compress(loc_samples, is_amp, axis=1)
is_del_pop = np.compress(loc_samples, is_del, axis=1)
amp_count_pop = np.sum(is_amp_pop, axis=1)
del_count_pop = np.sum(is_del_pop, axis=1)
amp_freq_pop = amp_count_pop / n_samples
del_freq_pop = del_count_pop / n_samples
df[f"{pop}_amp"] = amp_freq_pop
df[f"{pop}_del"] = del_freq_pop
raise ValueError(f"no samples for cohort {coh!r}")
is_amp_coh = np.compress(loc_samples, is_amp, axis=1)
is_del_coh = np.compress(loc_samples, is_del, axis=1)
amp_count_coh = np.sum(is_amp_coh, axis=1)
del_count_coh = np.sum(is_del_coh, axis=1)
amp_freq_coh = amp_count_coh / n_samples
del_freq_coh = del_count_coh / n_samples
df[f"{coh}_amp"] = amp_freq_coh
df[f"{coh}_del"] = del_freq_coh

# set gene ID as index for convenience
df.set_index("ID", inplace=True)
Expand Down
82 changes: 26 additions & 56 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 6 additions & 8 deletions tests/test_ag3.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def test_snp_effects():

def test_snp_allele_frequencies():
ag3 = setup_ag3()
populations = {
cohorts = {
"ke": "country == 'Kenya'",
"bf_2012_col": "country == 'Burkina Faso' and year == 2012 and species == 'coluzzii'",
}
Expand All @@ -600,7 +600,7 @@ def test_snp_allele_frequencies():
# drop invariants
df = ag3.snp_allele_frequencies(
transcript="AGAP009194-RA",
populations=populations,
cohorts=cohorts,
site_mask="gamb_colu",
sample_sets="v3_wild",
drop_invariant=True,
Expand All @@ -618,7 +618,7 @@ def test_snp_allele_frequencies():
# check invariant have been dropped
assert df.max_af.min() > 0

populations = {
cohorts = {
"gm": "country == 'Gambia, The'",
"mz": "country == 'Mozambique' and year == 2004",
}
Expand All @@ -633,7 +633,7 @@ def test_snp_allele_frequencies():
# keep invariants
df = ag3.snp_allele_frequencies(
transcript="AGAP004707-RD",
populations=populations,
cohorts=cohorts,
site_mask="gamb_colu",
sample_sets="v3_wild",
drop_invariant=False,
Expand Down Expand Up @@ -982,7 +982,7 @@ def test_gene_cnv(contig, sample_sets):
@pytest.mark.parametrize("contig", ["2R", "X"])
def test_gene_cnv_frequencies(contig):
ag3 = setup_ag3()
populations = {
cohorts = {
"ke": "country == 'Kenya'",
"bf_2012_col": "country == 'Burkina Faso' and year == 2012 and species == 'coluzzii'",
}
Expand All @@ -1000,9 +1000,7 @@ def test_gene_cnv_frequencies(contig):
]
df_genes = ag3.geneset().query(f"type == 'gene' and contig == '{contig}'")

df = ag3.gene_cnv_frequencies(
contig=contig, sample_sets="v3_wild", populations=populations
)
df = ag3.gene_cnv_frequencies(contig=contig, sample_sets="v3_wild", cohorts=cohorts)

assert isinstance(df, pd.DataFrame)
assert expected_cols == df.columns.tolist()
Expand Down

0 comments on commit 8c75c9e

Please sign in to comment.