diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml
index 66feaa829..5b4a4fd08 100644
--- a/.github/workflows/coverage.yml
+++ b/.github/workflows/coverage.yml
@@ -28,7 +28,7 @@ jobs:
run: poetry install
- name: Run tests with coverage
- run: poetry run pytest -v --cov malariagen_data/anoph --cov-report=xml tests/anoph
+ run: poetry run pytest --durations=20 --durations-min=1.0 -v --cov malariagen_data/anoph --cov-report=xml tests/anoph
- name: Upload coverage report
uses: codecov/codecov-action@v3
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 4c5e31ee1..28d8f13a9 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -41,7 +41,7 @@ jobs:
key: gcs_cache_tests_20231119
- name: Run full test suite
- run: poetry run pytest -v tests
+ run: poetry run pytest --durations=20 --durations-min=10.0 -v tests
- name: Save GCS cache
uses: actions/cache/save@v3
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 9bdb021ea..34d8c53c1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,5 +12,7 @@ repos:
hooks:
# Run the linter.
- id: ruff
+ args:
+ - "--fix"
# Run the formatter.
- id: ruff-format
diff --git a/malariagen_data/af1.py b/malariagen_data/af1.py
index eb1f3e342..fda0d39f2 100644
--- a/malariagen_data/af1.py
+++ b/malariagen_data/af1.py
@@ -128,6 +128,7 @@ def __init__(
tqdm_class=tqdm_class,
taxon_colors=TAXON_COLORS,
virtual_contigs=None,
+ gene_names=None,
)
def __repr__(self):
@@ -209,21 +210,3 @@ def _repr_html_(self):
"""
return html
-
- def _transcript_to_gene_name(self, transcript):
- df_genome_features = self.genome_features().set_index("ID")
- rec_transcript = df_genome_features.loc[transcript]
- parent = rec_transcript["Parent"]
-
- # E.g. manual overrides (used in Ag3)
- # if parent == "AGAP004707":
- # parent_name = "Vgsc/para"
- # else:
- # parent_name = rec_parent["Name"]
-
- # Note: Af1 doesn't have the "Name" attribute
- # rec_parent = df_genome_features.loc[parent]
- # parent_name = rec_parent["Name"]
- parent_name = parent
-
- return parent_name
diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py
index 8431b0d0d..512431077 100644
--- a/malariagen_data/ag3.py
+++ b/malariagen_data/ag3.py
@@ -27,6 +27,9 @@
"3RL": ("3R", "3L"),
"23X": ("2R", "2L", "3R", "3L", "X"),
}
+GENE_NAMES = {
+ "AGAP004707": "Vgsc/para",
+}
def _setup_aim_palettes():
@@ -191,6 +194,7 @@ def __init__(
tqdm_class=tqdm_class,
taxon_colors=TAXON_COLORS,
virtual_contigs=VIRTUAL_CONTIGS,
+ gene_names=GENE_NAMES,
)
# set up caches
@@ -293,20 +297,6 @@ def _repr_html_(self):
"""
return html
- def _transcript_to_gene_name(self, transcript):
- df_genome_features = self.genome_features().set_index("ID")
- rec_transcript = df_genome_features.loc[transcript]
- parent = rec_transcript["Parent"]
- rec_parent = df_genome_features.loc[parent]
-
- # manual overrides
- if parent == "AGAP004707":
- parent_name = "Vgsc/para"
- else:
- parent_name = rec_parent["Name"]
-
- return parent_name
-
def cross_metadata(self):
"""Load a dataframe containing metadata about samples in colony crosses,
including which samples are parents or progeny in which crosses.
diff --git a/malariagen_data/anoph/cnv_data.py b/malariagen_data/anoph/cnv_data.py
index fe50a639a..3a9cebf3c 100644
--- a/malariagen_data/anoph/cnv_data.py
+++ b/malariagen_data/anoph/cnv_data.py
@@ -163,60 +163,61 @@ def cnv_hmm(
regions: List[Region] = parse_multi_region(self, region)
del region
- debug("access CNV HMM data and concatenate as needed")
- lx = []
- for r in regions:
- ly = []
- for s in sample_sets:
- y = self._cnv_hmm_dataset(
- contig=r.contig,
- sample_set=s,
- inline_array=inline_array,
- chunks=chunks,
+ with self._spinner("Access CNV HMM data"):
+ debug("access CNV HMM data and concatenate as needed")
+ lx = []
+ for r in regions:
+ ly = []
+ for s in sample_sets:
+ y = self._cnv_hmm_dataset(
+ contig=r.contig,
+ sample_set=s,
+ inline_array=inline_array,
+ chunks=chunks,
+ )
+ ly.append(y)
+
+ debug("concatenate data from multiple sample sets")
+ x = simple_xarray_concat(ly, dim=DIM_SAMPLE)
+
+ debug("handle region, do this only once - optimisation")
+ if r.start is not None or r.end is not None:
+ start = x["variant_position"].values
+ end = x["variant_end"].values
+ index = pd.IntervalIndex.from_arrays(start, end, closed="both")
+ # noinspection PyArgumentList
+ other = pd.Interval(r.start, r.end, closed="both")
+ loc_region = index.overlaps(other) # type: ignore
+ x = x.isel(variants=loc_region)
+
+ lx.append(x)
+
+ debug("concatenate data from multiple regions")
+ ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
+
+ debug("handle sample query")
+ if sample_query is not None:
+ debug("load sample metadata")
+ df_samples = self.sample_metadata(sample_sets=sample_sets)
+
+ debug("align sample metadata with CNV data")
+ cnv_samples = ds["sample_id"].values.tolist()
+ df_samples_cnv = (
+ df_samples.set_index("sample_id").loc[cnv_samples].reset_index()
)
- ly.append(y)
-
- debug("concatenate data from multiple sample sets")
- x = simple_xarray_concat(ly, dim=DIM_SAMPLE)
-
- debug("handle region, do this only once - optimisation")
- if r.start is not None or r.end is not None:
- start = x["variant_position"].values
- end = x["variant_end"].values
- index = pd.IntervalIndex.from_arrays(start, end, closed="both")
- # noinspection PyArgumentList
- other = pd.Interval(r.start, r.end, closed="both")
- loc_region = index.overlaps(other) # type: ignore
- x = x.isel(variants=loc_region)
-
- lx.append(x)
-
- debug("concatenate data from multiple regions")
- ds = simple_xarray_concat(lx, dim=DIM_VARIANT)
-
- debug("handle sample query")
- if sample_query is not None:
- debug("load sample metadata")
- df_samples = self.sample_metadata(sample_sets=sample_sets)
-
- debug("align sample metadata with CNV data")
- cnv_samples = ds["sample_id"].values.tolist()
- df_samples_cnv = (
- df_samples.set_index("sample_id").loc[cnv_samples].reset_index()
- )
- debug("apply the query")
- loc_query_samples = df_samples_cnv.eval(sample_query).values
- if np.count_nonzero(loc_query_samples) == 0:
- raise ValueError(f"No samples found for query {sample_query!r}")
+ debug("apply the query")
+ loc_query_samples = df_samples_cnv.eval(sample_query).values
+ if np.count_nonzero(loc_query_samples) == 0:
+ raise ValueError(f"No samples found for query {sample_query!r}")
- ds = ds.isel(samples=loc_query_samples)
+ ds = ds.isel(samples=loc_query_samples)
- debug("handle coverage variance filter")
- if max_coverage_variance is not None:
- cov_var = ds["sample_coverage_variance"].values
- loc_pass_samples = cov_var <= max_coverage_variance
- ds = ds.isel(samples=loc_pass_samples)
+ debug("handle coverage variance filter")
+ if max_coverage_variance is not None:
+ cov_var = ds["sample_coverage_variance"].values
+ loc_pass_samples = cov_var <= max_coverage_variance
+ ds = ds.isel(samples=loc_pass_samples)
return ds
diff --git a/malariagen_data/anoph/frq_params.py b/malariagen_data/anoph/frq_params.py
index 417e507b0..579125b43 100644
--- a/malariagen_data/anoph/frq_params.py
+++ b/malariagen_data/anoph/frq_params.py
@@ -65,3 +65,8 @@
`gene_cnv_frequencies_advanced()`.
""",
]
+
+include_counts: TypeAlias = Annotated[
+ bool,
+ "Include columns with allele counts and number of non-missing allele calls (nobs).",
+]
diff --git a/malariagen_data/anoph/genome_features.py b/malariagen_data/anoph/genome_features.py
index 3f2bfaaaf..d6c6aa5bb 100644
--- a/malariagen_data/anoph/genome_features.py
+++ b/malariagen_data/anoph/genome_features.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Mapping
import bokeh.models
import bokeh.plotting
@@ -26,6 +26,7 @@ def __init__(
*,
gff_gene_type: str,
gff_default_attributes: Tuple[str, ...],
+ gene_names: Optional[Mapping[str, str]] = None,
**kwargs,
):
# N.B., this class is designed to work cooperatively, and
@@ -38,6 +39,11 @@ def __init__(
self._gff_gene_type = gff_gene_type
self._gff_default_attributes = gff_default_attributes
+ # Allow manual override of gene names.
+ if gene_names is None:
+ gene_names = dict()
+ self._gene_name_overrides = gene_names
+
# Setup caches.
self._cache_genome_features: Dict[Tuple[str, ...], pd.DataFrame] = dict()
@@ -45,7 +51,7 @@ def __init__(
def _geneset_gff3_path(self):
return self.config["GENESET_GFF3_PATH"]
- def geneset(self, *args, **kwargs):
+ def geneset(self, *args, **kwargs): # pragma: no cover
"""Deprecated, this method has been renamed to genome_features()."""
return self.genome_features(*args, **kwargs)
@@ -429,3 +435,21 @@ def _bokeh_style_genome_xaxis(fig, contig):
fig.xaxis.ticker = bokeh.models.AdaptiveTicker(min_interval=1)
fig.xaxis.minor_tick_line_color = None
fig.xaxis[0].formatter = bokeh.models.NumeralTickFormatter(format="0,0")
+
+ def _transcript_to_parent_name(self, transcript):
+ df_genome_features = self.genome_features().set_index("ID")
+
+ try:
+ rec_transcript = df_genome_features.loc[transcript]
+ except KeyError:
+ return None
+
+ parent_id = rec_transcript["Parent"]
+
+ try:
+ # Manual override.
+ return self._gene_name_overrides[parent_id]
+ except KeyError:
+ rec_parent = df_genome_features.loc[parent_id]
+ # Try to access "Name" attribute, fall back to "ID" if not present.
+ return rec_parent.get("Name", parent_id)
diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py
index fa158bd9c..eee4b1b95 100644
--- a/malariagen_data/anoph/sample_metadata.py
+++ b/malariagen_data/anoph/sample_metadata.py
@@ -969,3 +969,37 @@ def _setup_sample_hover_data_plotly(
if symbol and symbol not in hover_data:
hover_data.append(symbol)
return hover_data
+
+
+def locate_cohorts(*, cohorts, data):
+ # Build cohort dictionary where key=cohort_id, value=loc_coh.
+ coh_dict = {}
+
+ if isinstance(cohorts, Mapping):
+ # User has supplied a custom dictionary mapping cohort identifiers
+ # to pandas queries.
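+ # E.g., {"gambiae_2017": "taxon == 'gambiae' and year == 2017"} (illustrative query).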
+
+ for coh, query in cohorts.items():
+ loc_coh = data.eval(query).values
+ coh_dict[coh] = loc_coh
+
+ else:
+ assert isinstance(cohorts, str)
+ # User has supplied the name of a sample metadata column.
+
+ # Convenience to allow things like "admin1_year" instead of "cohort_admin1_year".
+ if "cohort_" + cohorts in data.columns:
+ cohorts = "cohort_" + cohorts
+
+ # Check the given cohort set exists.
+ if cohorts not in data.columns:
+ raise ValueError(f"{cohorts!r} is not a known column in the data.")
+ cohort_labels = data[cohorts].unique()
+
+ # Remove the nans and sort.
+ cohort_labels = sorted([c for c in cohort_labels if isinstance(c, str)])
+ for coh in cohort_labels:
+ loc_coh = data[cohorts] == coh
+ coh_dict[coh] = loc_coh.values
+
+ return coh_dict
diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py
new file mode 100644
index 000000000..07466ddb4
--- /dev/null
+++ b/malariagen_data/anoph/snp_frq.py
@@ -0,0 +1,1352 @@
+from typing import Optional, Dict, Union, Callable, List
+import warnings
+from textwrap import dedent
+
+import allel # type: ignore
+import numpy as np
+import pandas as pd
+from numpydoc_decorator import doc # type: ignore
+import xarray as xr
+import numba # type: ignore
+import plotly.express as px # type: ignore
+
+from .. import veff
+from ..util import check_types, pandas_apply
+from .snp_data import AnophelesSnpData
+from .sample_metadata import locate_cohorts
+from . import base_params, frq_params, map_params, plotly_params
+
+
+AA_CHANGE_QUERY = (
+ "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
+)
+
+
+class AnophelesSnpFrequencyAnalysis(
+ AnophelesSnpData,
+):
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ # N.B., this class is designed to work cooperatively, and
+ # so it's important that any remaining parameters are passed
+ # to the superclass constructor.
+ super().__init__(**kwargs)
+
+ # Set up cache variables.
+ self._cache_annotator = None
+
+ def _snp_df_melt(self, *, ds_snp: xr.Dataset) -> pd.DataFrame:
+ """Set up a dataframe with SNP site and filter data,
+ melting each alternate allele into a separate row."""
+
+ with self._spinner(desc="Prepare SNP dataframe"):
+ # Grab contig, pos, ref and alt.
+ contig_index = ds_snp["variant_contig"].values[0]
+ contig = ds_snp.attrs["contigs"][contig_index]
+ pos = ds_snp["variant_position"].values
+ alleles = ds_snp["variant_allele"].values
+ ref = alleles[:, 0]
+ alt = alleles[:, 1:]
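+ # N.B., each site has one reference allele and three alternate alleles.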
+
+ # Access site filters.
+ filter_pass = dict()
+ for m in self.site_mask_ids:
+ x = ds_snp[f"variant_filter_pass_{m}"].values
+ filter_pass[m] = x
+
+ # Set up columns with contig, pos, ref, alt columns, melting
+ # the data out to one row per alternate allele.
+ cols = {
+ "contig": contig,
+ "position": np.repeat(pos, 3),
+ "ref_allele": np.repeat(ref.astype("U1"), 3),
+ "alt_allele": alt.astype("U1").flatten(),
+ }
+
+ # Add mask columns.
+ for m in self.site_mask_ids:
+ x = filter_pass[m]
+ cols[f"pass_{m}"] = np.repeat(x, 3)
+
+ # Construct dataframe.
+ df_snps = pd.DataFrame(cols)
+
+ return df_snps
+
+ def _snp_effect_annotator(self):
+ """Set up variant effect annotator."""
+ if self._cache_annotator is None:
+ self._cache_annotator = veff.Annotator(
+ genome=self.open_genome(), genome_features=self.genome_features()
+ )
+ return self._cache_annotator
+
+ @check_types
+ @doc(
+ summary="Compute variant effects for a gene transcript.",
+ returns="""
+ A dataframe of all possible SNP variants and their effects, one row
+ per variant.
+ """,
+ )
+ def snp_effects(
+ self,
+ transcript: base_params.transcript,
+ site_mask: Optional[base_params.site_mask] = None,
+ ) -> pd.DataFrame:
+ # Access SNP data.
+ ds_snp = self.snp_variants(
+ region=transcript,
+ site_mask=site_mask,
+ )
+
+ # Setup initial dataframe of SNPs.
+ df_snps = self._snp_df_melt(ds_snp=ds_snp)
+
+ # Setup variant effect annotator.
+ ann = self._snp_effect_annotator()
+
+ # Add effects to the dataframe.
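+ # N.B., get_effects() adds the effect columns to the dataframe in place.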
+ ann.get_effects(transcript=transcript, variants=df_snps)
+
+ return df_snps
+
+ @check_types
+ @doc(
+ summary="""
+ Compute SNP allele frequencies for a gene transcript.
+ """,
+ returns="""
+ A dataframe of SNP allele frequencies, one row per variant allele.
+ """,
+ notes="""
+ Cohorts with fewer samples than `min_cohort_size` will be excluded from
+ the output data frame.
+ """,
+ )
+ def snp_allele_frequencies(
+ self,
+ transcript: base_params.transcript,
+ cohorts: base_params.cohorts,
+ sample_query: Optional[base_params.sample_query] = None,
+ min_cohort_size: base_params.min_cohort_size = 10,
+ site_mask: Optional[base_params.site_mask] = None,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ drop_invariant: frq_params.drop_invariant = True,
+ effects: frq_params.effects = True,
+ include_counts: frq_params.include_counts = False,
+ ) -> pd.DataFrame:
+ # Access sample metadata.
+ df_samples = self.sample_metadata(
+ sample_sets=sample_sets, sample_query=sample_query
+ )
+
+ # Build cohort dictionary, maps cohort labels to boolean indexers.
+ coh_dict = locate_cohorts(cohorts=cohorts, data=df_samples)
+
+ # Remove cohorts below minimum cohort size.
+ coh_dict = {
+ coh: loc_coh
+ for coh, loc_coh in coh_dict.items()
+ if np.count_nonzero(loc_coh) >= min_cohort_size
+ }
+
+ # Early check for no cohorts.
+ if len(coh_dict) == 0:
+ raise ValueError(
+ "No cohorts available for the given sample selection parameters and minimum cohort size."
+ )
+
+ # Access SNP data.
+ ds_snp = self.snp_calls(
+ region=transcript,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ )
+
+ # Early check for no SNPs.
+ if ds_snp.sizes["variants"] == 0: # pragma: no cover
+ raise ValueError("No SNPs available for the given region and site mask.")
+
+ # Access genotypes.
+ gt = ds_snp["call_genotype"].data
+ with self._dask_progress(desc="Load SNP genotypes"):
+ gt = gt.compute()
+
+ # Set up initial dataframe of SNPs.
+ df_snps = self._snp_df_melt(ds_snp=ds_snp)
+
+ # Count alleles.
+ count_cols = dict()
+ nobs_cols = dict()
+ freq_cols = dict()
+ cohorts_iterator = self._progress(
+ coh_dict.items(), desc="Compute allele frequencies"
+ )
+ for coh, loc_coh in cohorts_iterator:
+ n_samples = np.count_nonzero(loc_coh)
+ assert n_samples >= min_cohort_size
+ gt_coh = np.compress(loc_coh, gt, axis=1)
+ ac_coh = np.asarray(allel.GenotypeArray(gt_coh).count_alleles(max_allele=3))
+ an_coh = np.sum(ac_coh, axis=1)[:, None]
+ with np.errstate(divide="ignore", invalid="ignore"):
+ af_coh = np.where(an_coh > 0, ac_coh / an_coh, np.nan)
+ # Melt the frequencies so we get one row per alternate allele.
+ frq = af_coh[:, 1:].flatten()
+ freq_cols["frq_" + coh] = frq
+ count = ac_coh[:, 1:].flatten()
+ count_cols["count_" + coh] = count
+ nobs = np.repeat(an_coh[:, 0], 3)
+ nobs_cols["nobs_" + coh] = nobs
+
+ # Build a dataframe with the frequency columns.
+ df_freqs = pd.DataFrame(freq_cols)
+ df_counts = pd.DataFrame(count_cols)
+ df_nobs = pd.DataFrame(nobs_cols)
+
+ # Compute max_af.
+ df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})
+
+ # Build the final dataframe.
+ df_snps.reset_index(drop=True, inplace=True)
+ if include_counts:
+ df_snps = pd.concat(
+ [df_snps, df_freqs, df_max_af, df_counts, df_nobs], axis=1
+ )
+ else:
+ df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1)
+
+ # Drop invariants.
+ if drop_invariant:
+ loc_variant = df_snps["max_af"] > 0
+
+ # Check for no SNPs remaining after dropping invariants.
+ if np.count_nonzero(loc_variant) == 0: # pragma: no cover
+ raise ValueError("No SNPs remaining after dropping invariant SNPs.")
+
+ df_snps = df_snps.loc[loc_variant]
+
+ # Reset index after filtering.
+ df_snps.reset_index(inplace=True, drop=True)
+
+ if effects:
+ # Add effect annotations.
+ ann = self._snp_effect_annotator()
+ ann.get_effects(
+ transcript=transcript, variants=df_snps, progress=self._progress
+ )
+
+ # Add label.
+ df_snps["label"] = pandas_apply(
+ _make_snp_label_effect,
+ df_snps,
+ columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
+ )
+
+ # Set index.
+ df_snps.set_index(
+ ["contig", "position", "ref_allele", "alt_allele", "aa_change"],
+ inplace=True,
+ )
+
+ else:
+ # Add label.
+ df_snps["label"] = pandas_apply(
+ _make_snp_label,
+ df_snps,
+ columns=["contig", "position", "ref_allele", "alt_allele"],
+ )
+
+ # Set index.
+ df_snps.set_index(
+ ["contig", "position", "ref_allele", "alt_allele"],
+ inplace=True,
+ )
+
+ # Add dataframe metadata.
+ gene_name = self._transcript_to_parent_name(transcript)
+ title = transcript
+ if gene_name:
+ title += f" ({gene_name})"
+ title += " SNP frequencies"
+ df_snps.attrs["title"] = title
+
+ return df_snps
+
+ @check_types
+ @doc(
+ summary="""
+ Compute amino acid substitution frequencies for a gene transcript.
+ """,
+ returns="""
+ A dataframe of amino acid allele frequencies, one row per
+ substitution.
+ """,
+ notes="""
+ Cohorts with fewer samples than `min_cohort_size` will be excluded from
+ the output data frame.
+ """,
+ )
+ def aa_allele_frequencies(
+ self,
+ transcript: base_params.transcript,
+ cohorts: base_params.cohorts,
+ sample_query: Optional[base_params.sample_query] = None,
+ min_cohort_size: Optional[base_params.min_cohort_size] = 10,
+ site_mask: Optional[base_params.site_mask] = None,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ drop_invariant: frq_params.drop_invariant = True,
+ include_counts: frq_params.include_counts = False,
+ ) -> pd.DataFrame:
+ df_snps = self.snp_allele_frequencies(
+ transcript=transcript,
+ cohorts=cohorts,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=drop_invariant,
+ effects=True,
+ include_counts=include_counts,
+ )
+ df_snps.reset_index(inplace=True)
+
+ # We just want aa change.
+ df_ns_snps = df_snps.query(AA_CHANGE_QUERY).copy()
+
+ # Early check for no matching SNPs.
+ if len(df_ns_snps) == 0: # pragma: no cover
+ raise ValueError(
+ "No amino acid change SNPs found for the given transcript and site mask."
+ )
+
+ # N.B., we need to worry about the possibility of the
+ # same aa change due to SNPs at different positions. We cannot
+ # sum frequencies of SNPs at different genomic positions. This
+ # is why we group by position and aa_change, not just aa_change.
+
+ # Group and sum to collapse multi variant allele changes.
+ freq_cols = [col for col in df_ns_snps if col.startswith("frq_")]
+
+ # Special handling here to ensure nans don't get summed to zero.
+ # See also https://github.com/pandas-dev/pandas/issues/20824#issuecomment-705376621
+ def np_sum(g):
+ return np.sum(g.values)
+
+ agg: Dict[str, Union[Callable, str]] = {c: np_sum for c in freq_cols}
+
+ # Add in counts and observations data if requested.
+ if include_counts:
+ count_cols = [col for col in df_ns_snps if col.startswith("count_")]
+ for c in count_cols:
+ agg[c] = "sum"
+ nobs_cols = [col for col in df_ns_snps if col.startswith("nobs_")]
+ for c in nobs_cols:
+ agg[c] = "first"
+
+ keep_cols = (
+ "contig",
+ "transcript",
+ "aa_pos",
+ "ref_allele",
+ "ref_aa",
+ "alt_aa",
+ "effect",
+ "impact",
+ )
+ for c in keep_cols:
+ agg[c] = "first"
+ agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v
+ df_aaf = df_ns_snps.groupby(["position", "aa_change"]).agg(agg).reset_index()
+
+ # Compute new max_af.
+ df_aaf["max_af"] = df_aaf[freq_cols].max(axis=1)
+
+ # Add label.
+ df_aaf["label"] = pandas_apply(
+ _make_snp_label_aa,
+ df_aaf,
+ columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"],
+ )
+
+ # Sort by genomic position.
+ df_aaf = df_aaf.sort_values(["position", "aa_change"])
+
+ # Set index.
+ df_aaf.set_index(["aa_change", "contig", "position"], inplace=True)
+
+ # Add metadata.
+ gene_name = self._transcript_to_parent_name(transcript)
+ title = transcript
+ if gene_name:
+ title += f" ({gene_name})"
+ title += " SNP frequencies"
+ df_aaf.attrs["title"] = title
+
+ return df_aaf
+
+ @check_types
+ @doc(
+ summary="""
+ Group samples by taxon, area (space) and period (time), then compute
+ SNP allele frequencies.
+ """,
+ returns="""
+ The resulting dataset contains data with dimensions "cohorts" and
+ "variants". Variables prefixed with "cohort" are 1-dimensional
+ arrays with data about the cohorts, such as the area, period, taxon
+ and cohort size. Variables prefixed with "variant" are
+ 1-dimensional arrays with data about the variants, such as the
+ contig, position, reference and alternate alleles. Variables
+ prefixed with "event" are 2-dimensional arrays with the allele
+ counts and frequency calculations.
+ """,
+ )
+ def snp_allele_frequencies_advanced(
+ self,
+ transcript: base_params.transcript,
+ area_by: frq_params.area_by,
+ period_by: frq_params.period_by,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ min_cohort_size: base_params.min_cohort_size = 10,
+ drop_invariant: frq_params.drop_invariant = True,
+ variant_query: Optional[frq_params.variant_query] = None,
+ site_mask: Optional[base_params.site_mask] = None,
+ nobs_mode: frq_params.nobs_mode = frq_params.nobs_mode_default,
+ ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default,
+ ) -> xr.Dataset:
+ # Load sample metadata.
+ df_samples = self.sample_metadata(
+ sample_sets=sample_sets, sample_query=sample_query
+ )
+
+ # Prepare sample metadata for cohort grouping.
+ df_samples = _prep_samples_for_cohort_grouping(
+ df_samples=df_samples,
+ area_by=area_by,
+ period_by=period_by,
+ )
+
+ # Group samples to make cohorts.
+ group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
+
+ # Build cohorts dataframe.
+ df_cohorts = _build_cohorts_from_sample_grouping(
+ group_samples_by_cohort=group_samples_by_cohort,
+ min_cohort_size=min_cohort_size,
+ )
+
+ # Early check for no cohorts.
+ if len(df_cohorts) == 0:
+ raise ValueError(
+ "No cohorts available for the given sample selection parameters and minimum cohort size."
+ )
+
+ # Access SNP calls.
+ ds_snps = self.snp_calls(
+ region=transcript,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ site_mask=site_mask,
+ )
+
+ # Early check for no SNPs.
+ if ds_snps.sizes["variants"] == 0: # pragma: no cover
+ raise ValueError("No SNPs available for the given region and site mask.")
+
+ # Access genotypes.
+ gt = ds_snps["call_genotype"].data
+ with self._dask_progress(desc="Load SNP genotypes"):
+ gt = gt.compute()
+
+ # Set up variant variables.
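+ # Each site is melted into 3 rows, one per alternate allele.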
+ contigs = ds_snps.attrs["contigs"]
+ variant_contig = np.repeat(
+ [contigs[i] for i in ds_snps["variant_contig"].values], 3
+ )
+ variant_position = np.repeat(ds_snps["variant_position"].values, 3)
+ alleles = ds_snps["variant_allele"].values
+ variant_ref_allele = np.repeat(alleles[:, 0], 3)
+ variant_alt_allele = alleles[:, 1:].flatten()
+ variant_pass = dict()
+ for site_mask in self.site_mask_ids:
+ variant_pass[site_mask] = np.repeat(
+ ds_snps[f"variant_filter_pass_{site_mask}"].values, 3
+ )
+
+ # Set up main event variables.
+ n_variants, n_cohorts = len(variant_position), len(df_cohorts)
+ count = np.zeros((n_variants, n_cohorts), dtype=int)
+ nobs = np.zeros((n_variants, n_cohorts), dtype=int)
+
+ # Build event count and nobs for each cohort.
+ cohorts_iterator = self._progress(
+ enumerate(df_cohorts.itertuples()),
+ total=len(df_cohorts),
+ desc="Compute SNP allele frequencies",
+ )
+ for cohort_index, cohort in cohorts_iterator:
+ cohort_key = cohort.taxon, cohort.area, cohort.period
+ sample_indices = group_samples_by_cohort.indices[cohort_key]
+
+ cohort_ac, cohort_an = _cohort_alt_allele_counts_melt(
+ gt=gt,
+ indices=sample_indices,
+ max_allele=3,
+ )
+ count[:, cohort_index] = cohort_ac
+
+ if nobs_mode == "called":
+ nobs[:, cohort_index] = cohort_an
+ else:
+ assert nobs_mode == "fixed"
+ nobs[:, cohort_index] = cohort.size * 2
+
+ # Compute frequency.
+ with np.errstate(divide="ignore", invalid="ignore"):
+ # Ignore division warnings.
+ frequency = count / nobs
+
+ # Compute maximum frequency over cohorts.
+ with warnings.catch_warnings():
+ # Ignore "All-NaN slice encountered" warnings.
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ max_af = np.nanmax(frequency, axis=1)
+
+ # Make dataframe of SNPs.
+ df_variants_cols = {
+ "contig": variant_contig,
+ "position": variant_position,
+ "ref_allele": variant_ref_allele.astype("U1"),
+ "alt_allele": variant_alt_allele.astype("U1"),
+ "max_af": max_af,
+ }
+ for site_mask in self.site_mask_ids:
+ df_variants_cols[f"pass_{site_mask}"] = variant_pass[site_mask]
+ df_variants = pd.DataFrame(df_variants_cols)
+
+ # Deal with SNP alleles not observed.
+ if drop_invariant:
+ loc_variant = max_af > 0
+
+ # Check for no SNPs remaining after dropping invariants.
+ if np.count_nonzero(loc_variant) == 0: # pragma: no cover
+ raise ValueError("No SNPs remaining after dropping invariant SNPs.")
+
+ df_variants = df_variants.loc[loc_variant].reset_index(drop=True)
+ count = np.compress(loc_variant, count, axis=0)
+ nobs = np.compress(loc_variant, nobs, axis=0)
+ frequency = np.compress(loc_variant, frequency, axis=0)
+
+ # Set up variant effect annotator.
+ ann = self._snp_effect_annotator()
+
+ # Add effects to the dataframe.
+ ann.get_effects(
+ transcript=transcript, variants=df_variants, progress=self._progress
+ )
+
+ # Add variant labels.
+ df_variants["label"] = pandas_apply(
+ _make_snp_label_effect,
+ df_variants,
+ columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
+ )
+
+ # Build the output dataset.
+ ds_out = xr.Dataset()
+
+ # Cohort variables.
+ for coh_col in df_cohorts.columns:
+ ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]
+
+ # Variant variables.
+ for snp_col in df_variants.columns:
+ ds_out[f"variant_{snp_col}"] = "variants", df_variants[snp_col]
+
+ # Event variables.
+ ds_out["event_count"] = ("variants", "cohorts"), count
+ ds_out["event_nobs"] = ("variants", "cohorts"), nobs
+ ds_out["event_frequency"] = ("variants", "cohorts"), frequency
+
+ # Apply variant query.
+ if variant_query is not None:
+ loc_variants = df_variants.eval(variant_query).values
+
+ # Check for no SNPs remaining after applying variant query.
+ if np.count_nonzero(loc_variants) == 0:
+ raise ValueError(
+ f"No SNPs remaining after applying variant query {variant_query!r}."
+ )
+
+ ds_out = ds_out.isel(variants=loc_variants)
+
+ # Add confidence intervals.
+ _add_frequency_ci(ds=ds_out, ci_method=ci_method)
+
+ # Tidy up display by sorting variables.
+ sorted_vars: List[str] = sorted([str(k) for k in ds_out.keys()])
+ ds_out = ds_out[sorted_vars]
+
+ # Add metadata.
+ gene_name = self._transcript_to_parent_name(transcript)
+ title = transcript
+ if gene_name:
+ title += f" ({gene_name})"
+ title += " SNP frequencies"
+ ds_out.attrs["title"] = title
+
+ return ds_out
+
+ @check_types
+ @doc(
+ summary="""
+ Group samples by taxon, area (space) and period (time), then compute
+ amino acid change allele frequencies.
+ """,
+ returns="""
+ The resulting dataset contains data with dimensions "cohorts" and
+ "variants". Variables prefixed with "cohort" are 1-dimensional
+ arrays with data about the cohorts, such as the area, period, taxon
+ and cohort size. Variables prefixed with "variant" are
+ 1-dimensional arrays with data about the variants, such as the
+ contig, position, reference and alternate alleles. Variables
+ prefixed with "event" are 2-dimensional arrays with the allele
+ counts and frequency calculations.
+ """,
+ )
+ def aa_allele_frequencies_advanced(
+ self,
+ transcript: base_params.transcript,
+ area_by: frq_params.area_by,
+ period_by: frq_params.period_by,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ min_cohort_size: base_params.min_cohort_size = 10,
+ variant_query: Optional[frq_params.variant_query] = None,
+ site_mask: Optional[base_params.site_mask] = None,
+ nobs_mode: frq_params.nobs_mode = "called",
+ ci_method: Optional[frq_params.ci_method] = "wilson",
+ ) -> xr.Dataset:
+ # Begin by computing SNP allele frequencies.
+ ds_snp_frq = self.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ drop_invariant=True, # always drop invariant for aa frequencies
+ variant_query=AA_CHANGE_QUERY, # we'll also apply a variant query later
+ site_mask=site_mask,
+ nobs_mode=nobs_mode,
+ ci_method=None, # we will recompute confidence intervals later
+ )
+
+ # N.B., we need to worry about the possibility of the
+ # same aa change due to SNPs at different positions. We cannot
+ # sum frequencies of SNPs at different genomic positions. This
+ # is why we group by position and aa_change, not just aa_change.
+
+ # Add in a special grouping column to work around the fact that xarray currently
+ # doesn't support grouping by multiple variables in the same dimension.
+ df_grouper = ds_snp_frq[
+ ["variant_position", "variant_aa_change"]
+ ].to_dataframe()
+ grouper_var = df_grouper.apply(
+ lambda row: "_".join([str(v) for v in row]), axis="columns"
+ )
+ ds_snp_frq["variant_position_aa_change"] = "variants", grouper_var
+
+ # Group by position and amino acid change.
+ group_by_aa_change = ds_snp_frq.groupby("variant_position_aa_change")
+
+ # Apply aggregation.
+ ds_aa_frq = group_by_aa_change.map(_map_snp_to_aa_change_frq_ds)
+
+ # Add back in cohort variables, unaffected by aggregation.
+ cohort_vars = [v for v in ds_snp_frq if v.startswith("cohort_")]
+ for v in cohort_vars:
+ ds_aa_frq[v] = ds_snp_frq[v]
+
+ # Sort by genomic position.
+ ds_aa_frq = ds_aa_frq.sortby(["variant_position", "variant_aa_change"])
+
+ # Recompute frequency.
+ count = ds_aa_frq["event_count"].values
+ nobs = ds_aa_frq["event_nobs"].values
+ with np.errstate(divide="ignore", invalid="ignore"):
+ frequency = count / nobs # ignore division warnings
+ ds_aa_frq["event_frequency"] = ("variants", "cohorts"), frequency
+
+ # Recompute max frequency over cohorts.
+ with warnings.catch_warnings():
+ # Ignore "All-NaN slice encountered" warnings.
+ warnings.simplefilter("ignore", category=RuntimeWarning)
+ max_af = np.nanmax(ds_aa_frq["event_frequency"].values, axis=1)
+ ds_aa_frq["variant_max_af"] = "variants", max_af
+
+ # Set up variant dataframe, useful intermediate.
+ variant_cols = [v for v in ds_aa_frq if v.startswith("variant_")]
+ df_variants = ds_aa_frq[variant_cols].to_dataframe()
+ df_variants.columns = [c.split("variant_")[1] for c in df_variants.columns]
+
+ # Assign new variant label.
+ label = pandas_apply(
+ _make_snp_label_aa,
+ df_variants,
+ columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"],
+ )
+ ds_aa_frq["variant_label"] = "variants", label
+
+ # Apply variant query if given.
+ if variant_query is not None:
+ loc_variants = df_variants.eval(variant_query).values
+
+ # Check for no SNPs remaining after applying variant query.
+ if np.count_nonzero(loc_variants) == 0:
+ raise ValueError(
+ f"No SNPs remaining after applying variant query {variant_query!r}."
+ )
+
+ ds_aa_frq = ds_aa_frq.isel(variants=loc_variants)
+
+ # Compute new confidence intervals.
+ _add_frequency_ci(ds=ds_aa_frq, ci_method=ci_method)
+
+ # Tidy up display by sorting variables.
+ ds_aa_frq = ds_aa_frq[sorted(ds_aa_frq)]
+
+ gene_name = self._transcript_to_parent_name(transcript)
+ title = transcript
+ if gene_name:
+ title += f" ({gene_name})"
+ title += " SNP frequencies"
+ ds_aa_frq.attrs["title"] = title
+
+ return ds_aa_frq
+
+ @check_types
+ @doc(
+ summary="""
+ Plot a heatmap from a pandas DataFrame of frequencies, e.g., output
+ from `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
+ """,
+ parameters=dict(
+ df="""
+ A DataFrame of frequencies, e.g., output from
+ `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
+ """,
+ index="""
+ One or more column headers that are present in the input dataframe.
+ These become the heatmap y-axis row labels. The column(s) must
+ produce a unique index.
+ """,
+ max_len="""
+ Displaying large styled dataframes may cause ipython notebooks to
+ crash. If the input dataframe is larger than this value, an error
+ will be raised.
+ """,
+ col_width="""
+ Plot width per column in pixels (px).
+ """,
+ row_height="""
+ Plot height per row in pixels (px).
+ """,
+ kwargs="""
+ Passed through to `px.imshow()`.
+ """,
+ ),
+ notes="""
+ It's recommended to filter the input DataFrame to just rows of interest,
+ i.e., fewer rows than `max_len`.
+ """,
+ )
+ def plot_frequencies_heatmap(
+ self,
+ df: pd.DataFrame,
+ index: Optional[Union[str, List[str]]] = "label",
+ max_len: Optional[int] = 100,
+ col_width: int = 40,
+ row_height: int = 20,
+ x_label: plotly_params.x_label = "Cohorts",
+ y_label: plotly_params.y_label = "Variants",
+ colorbar: plotly_params.colorbar = True,
+ width: plotly_params.width = None,
+ height: plotly_params.height = None,
+ text_auto: plotly_params.text_auto = ".0%",
+ aspect: plotly_params.aspect = "auto",
+ color_continuous_scale: plotly_params.color_continuous_scale = "Reds",
+ title: plotly_params.title = True,
+ show: plotly_params.show = True,
+ renderer: plotly_params.renderer = None,
+ **kwargs,
+ ) -> plotly_params.figure:
+ # Check len of input.
+ if max_len and len(df) > max_len:
+ raise ValueError(
+ dedent(
+ f"""
+ Input DataFrame is longer than max_len parameter value {max_len}, which means
+ that the plot is likely to be very large. If you really want to go ahead,
+ please rerun the function with max_len=None.
+ """
+ )
+ )
+
+ # Handle title.
+ if title is True:
+ title = df.attrs.get("title", None)
+
+ # Indexing.
+ if index is None:
+ index = list(df.index.names)
+ df = df.reset_index().copy()
+ if isinstance(index, list):
+ index_col = (
+ df[index]
+ .astype(str)
+ .apply(
+ lambda row: ", ".join([o for o in row if o is not None]),
+ axis="columns",
+ )
+ )
+ else:
+ assert isinstance(index, str)
+ index_col = df[index].astype(str)
+
+ # Check that index is unique.
+ if not index_col.is_unique:
+ raise ValueError(f"{index} does not produce a unique index")
+
+ # Drop and re-order columns.
+ frq_cols = [col for col in df.columns if col.startswith("frq_")]
+
+ # Keep only freq cols.
+ heatmap_df = df[frq_cols].copy()
+
+ # Set index.
+ heatmap_df.set_index(index_col, inplace=True)
+
+ # Clean column names.
+ heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_")
+
+ # Deal with width and height.
+ if width is None:
+ width = 400 + col_width * len(heatmap_df.columns)
+ if colorbar:
+ width += 40
+ if height is None:
+ height = 200 + row_height * len(heatmap_df)
+ if title is not None:
+ height += 40
+
+ # Plotly heatmap styling.
+ fig = px.imshow(
+ img=heatmap_df,
+ zmin=0,
+ zmax=1,
+ width=width,
+ height=height,
+ text_auto=text_auto,
+ aspect=aspect,
+ color_continuous_scale=color_continuous_scale,
+ title=title,
+ **kwargs,
+ )
+
+ fig.update_xaxes(side="bottom", tickangle=30)
+ if x_label is not None:
+ fig.update_xaxes(title=x_label)
+ if y_label is not None:
+ fig.update_yaxes(title=y_label)
+ fig.update_layout(
+ coloraxis_colorbar=dict(
+ title="Frequency",
+ tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
+ ticktext=["0%", "20%", "40%", "60%", "80%", "100%"],
+ )
+ )
+ if not colorbar:
+ fig.update(layout_coloraxis_showscale=False)
+
+ if show: # pragma: no cover
+ fig.show(renderer=renderer)
+ return None
+ else:
+ return fig
+
+ @check_types
+ @doc(
+ summary="Create a time series plot of variant frequencies using plotly.",
+ parameters=dict(
+ ds="""
+ A dataset of variant frequencies, such as returned by
+ `snp_allele_frequencies_advanced()`,
+ `aa_allele_frequencies_advanced()` or
+ `gene_cnv_frequencies_advanced()`.
+ """,
+ kwargs="Passed through to `px.line()`.",
+ ),
+ returns="""
+ A plotly figure containing line graphs. The resulting figure will
+ have one panel per cohort, grouped into columns by taxon, and
+ grouped into rows by area. Markers and lines show frequencies of
+ variants.
+ """,
+ )
+ def plot_frequencies_time_series(
+ self,
+ ds: xr.Dataset,
+ height: plotly_params.height = None,
+ width: plotly_params.width = None,
+ title: plotly_params.title = True,
+ legend_sizing: plotly_params.legend_sizing = "constant",
+ show: plotly_params.show = True,
+ renderer: plotly_params.renderer = None,
+ **kwargs,
+ ) -> plotly_params.figure:
+ # Handle title.
+ if title is True:
+ title = ds.attrs.get("title", None)
+
+ # Extract cohorts into a dataframe.
+ cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
+ df_cohorts = ds[cohort_vars].to_dataframe()
+ df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore
+
+ # Extract variant labels.
+ variant_labels = ds["variant_label"].values
+
+ # Build a long-form dataframe from the dataset.
+ dfs = []
+ for cohort_index, cohort in enumerate(df_cohorts.itertuples()):
+ ds_cohort = ds.isel(cohorts=cohort_index)
+ df = pd.DataFrame(
+ {
+ "taxon": cohort.taxon,
+ "area": cohort.area,
+ "date": cohort.period_start,
+ "period": str(
+ cohort.period
+ ), # use string representation for hover label
+ "sample_size": cohort.size,
+ "variant": variant_labels,
+ "count": ds_cohort["event_count"].values,
+ "nobs": ds_cohort["event_nobs"].values,
+ "frequency": ds_cohort["event_frequency"].values,
+ "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values,
+ "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values,
+ }
+ )
+ dfs.append(df)
+ df_events = pd.concat(dfs, axis=0).reset_index(drop=True)
+
+ # Remove events with no observations.
+ df_events = df_events.query("nobs > 0").copy()
+
+ # Calculate error bars.
+ frq = df_events["frequency"]
+ frq_ci_low = df_events["frequency_ci_low"]
+ frq_ci_upp = df_events["frequency_ci_upp"]
+ df_events["frequency_error"] = frq_ci_upp - frq
+ df_events["frequency_error_minus"] = frq - frq_ci_low
+
+ # Make a plot.
+ fig = px.line(
+ df_events,
+ facet_col="taxon",
+ facet_row="area",
+ x="date",
+ y="frequency",
+ error_y="frequency_error",
+ error_y_minus="frequency_error_minus",
+ color="variant",
+ markers=True,
+ hover_name="variant",
+ hover_data={
+ "frequency": ":.0%",
+ "period": True,
+ "area": True,
+ "taxon": True,
+ "sample_size": True,
+ "date": False,
+ "variant": False,
+ },
+ height=height,
+ width=width,
+ title=title,
+ labels={
+ "date": "Date",
+ "frequency": "Frequency",
+ "variant": "Variant",
+ "taxon": "Taxon",
+ "area": "Area",
+ "period": "Period",
+ "sample_size": "Sample size",
+ },
+ **kwargs,
+ )
+
+ # Tidy plot.
+ fig.update_layout(
+ yaxis_range=[-0.05, 1.05],
+ legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
+ )
+
+ if show: # pragma: no cover
+ fig.show(renderer=renderer)
+ return None
+ else:
+ return fig
+
+ @check_types
+ @doc(
+ summary="""
+ Plot markers on a map showing variant frequencies for cohorts grouped
+ by area (space), period (time) and taxon.
+ """,
+ parameters=dict(
+ m="The map on which to add the markers.",
+ variant="Index or label of variant to plot.",
+ taxon="Taxon to show markers for.",
+ period="Time period to show markers for.",
+ clear="""
+ If True, clear all layers (except the base layer) from the map
+ before adding new markers.
+ """,
+ ),
+ )
+ def plot_frequencies_map_markers(
+ self,
+ m,
+ ds: frq_params.ds_frequencies_advanced,
+ variant: Union[int, str],
+ taxon: str,
+ period: pd.Period,
+ clear: bool = True,
+ ):
+ # Only import here because of some problems importing globally.
+ import ipyleaflet # type: ignore
+ import ipywidgets # type: ignore
+
+ # Slice dataset to variant of interest.
+ if isinstance(variant, int):
+ ds_variant = ds.isel(variants=variant)
+ variant_label = ds["variant_label"].values[variant]
+ else:
+ assert isinstance(variant, str)
+ ds_variant = ds.set_index(variants="variant_label").sel(variants=variant)
+ variant_label = variant
+
+ # Convert to a dataframe for convenience.
+ df_markers = ds_variant[
+ [
+ "cohort_taxon",
+ "cohort_area",
+ "cohort_period",
+ "cohort_lat_mean",
+ "cohort_lon_mean",
+ "cohort_size",
+ "event_frequency",
+ "event_frequency_ci_low",
+ "event_frequency_ci_upp",
+ ]
+ ].to_dataframe()
+
+ # Select data matching taxon and period parameters.
+ df_markers = df_markers.loc[
+ (
+ (df_markers["cohort_taxon"] == taxon)
+ & (df_markers["cohort_period"] == period)
+ )
+ ]
+
+ # Clear existing layers in the map.
+ if clear:
+ for layer in m.layers[1:]:
+ m.remove_layer(layer)
+
+ # Add markers.
+ for x in df_markers.itertuples():
+ marker = ipyleaflet.CircleMarker()
+ marker.location = (x.cohort_lat_mean, x.cohort_lon_mean)
+ marker.radius = 20
+ marker.color = "black"
+ marker.weight = 1
+ marker.fill_color = "red"
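+ # Use marker opacity to encode the variant frequency in this cohort.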
+ marker.fill_opacity = x.event_frequency
+ popup_html = f"""
+ {variant_label}
+ Taxon: {x.cohort_taxon}
+ Area: {x.cohort_area}
+ Period: {x.cohort_period}
+ Sample size: {x.cohort_size}
+ Frequency: {x.event_frequency:.0%}
+ (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%})
+ """
+ marker.popup = ipyleaflet.Popup(
+ child=ipywidgets.HTML(popup_html),
+ )
+ m.add(marker)
+
+ @check_types
+ @doc(
+ summary="""
+ Create an interactive map with markers showing variant frequencies for
+ cohorts grouped by area (space), period (time) and taxon.
+ """,
+ parameters=dict(
+ title="""
+ If True, attempt to use metadata from input dataset as a plot
+ title. Otherwise, use supplied value as a title.
+ """,
+ epilogue="Additional text to display below the map.",
+ ),
+ returns="""
+ An interactive map with widgets for selecting which variant, taxon
+ and time period to display.
+ """,
+ )
+ def plot_frequencies_interactive_map(
+ self,
+ ds: frq_params.ds_frequencies_advanced,
+ center: map_params.center = map_params.center_default,
+ zoom: map_params.zoom = map_params.zoom_default,
+ title: Union[bool, str] = True,
+ epilogue: Union[bool, str] = True,
+ ):
+ import ipyleaflet
+ import ipywidgets
+
+ # Handle title.
+ if title is True:
+ title = ds.attrs.get("title", None)
+
+ # Create a map.
+ freq_map = ipyleaflet.Map(center=center, zoom=zoom)
+
+ # Set up interactive controls.
+ variants = ds["variant_label"].values
+ taxa = np.unique(ds["cohort_taxon"].values)
+ periods = np.unique(ds["cohort_period"].values)
+ controls = ipywidgets.interactive(
+ self.plot_frequencies_map_markers,
+ m=ipywidgets.fixed(freq_map),
+ ds=ipywidgets.fixed(ds),
+ variant=ipywidgets.Dropdown(options=variants, description="Variant: "),
+ taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "),
+ period=ipywidgets.Dropdown(options=periods, description="Period: "),
+ clear=ipywidgets.fixed(True),
+ )
+
+ # Lay out widgets.
+ components = []
+ if title is not None:
+ components.append(ipywidgets.HTML(value=f"
{title}
"))
+ components.append(controls)
+ components.append(freq_map)
+ if epilogue is True:
+ epilogue = """
+ Variant frequencies are shown as coloured markers. Opacity of color
+ denotes frequency. Click on a marker for more information.
+ """
+ if epilogue:
+ components.append(ipywidgets.HTML(value=f"{epilogue}"))
+
+ out = ipywidgets.VBox(components)
+
+ return out
+
+
+def _make_snp_label(contig, position, ref_allele, alt_allele):
+ return f"{contig}:{position:,} {ref_allele}>{alt_allele}"
+
+
+def _make_snp_label_effect(contig, position, ref_allele, alt_allele, aa_change):
+ label = f"{contig}:{position:,} {ref_allele}>{alt_allele}"
+ if isinstance(aa_change, str):
+ label += f" ({aa_change})"
+ return label
+
+
+def _make_snp_label_aa(aa_change, contig, position, ref_allele, alt_allele):
+ label = f"{aa_change} ({contig}:{position:,} {ref_allele}>{alt_allele})"
+ return label
+
+
+def _make_sample_period_month(row):
+ year = row.year
+ month = row.month
+ if year > 0 and month > 0:
+ return pd.Period(freq="M", year=year, month=month)
+ else:
+ return pd.NaT
+
+
+def _make_sample_period_quarter(row):
+ year = row.year
+ month = row.month
+ if year > 0 and month > 0:
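+ # N.B., pandas maps the given month to its containing quarter.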
+ return pd.Period(freq="Q", year=year, month=month)
+ else:
+ return pd.NaT
+
+
+def _make_sample_period_year(row):
+ year = row.year
+ if year > 0:
+ return pd.Period(freq="A", year=year)
+ else:
+ return pd.NaT
+
+
+def _prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by):
+ # Take a copy, as we will modify the dataframe.
+ df_samples = df_samples.copy()
+
+ # Fix intermediate taxon values - we only want to build cohorts with clean
+ # taxon calls, so we set intermediate values to None.
+ loc_intermediate_taxon = (
+ df_samples["taxon"].str.startswith("intermediate").fillna(False)
+ )
+ df_samples.loc[loc_intermediate_taxon, "taxon"] = None
+
+ # Add period column.
+ if period_by == "year":
+ make_period = _make_sample_period_year
+ elif period_by == "quarter":
+ make_period = _make_sample_period_quarter
+ elif period_by == "month":
+ make_period = _make_sample_period_month
+ else: # pragma: no cover
+ raise ValueError(
+ f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}."
+ )
+ sample_period = df_samples.apply(make_period, axis="columns")
+ df_samples["period"] = sample_period
+
+ # Add area column for consistent output.
+ df_samples["area"] = df_samples[area_by]
+
+ return df_samples
+
+
+def _build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size):
+ # Build cohorts dataframe.
+ df_cohorts = group_samples_by_cohort.agg(
+ size=("sample_id", len),
+ lat_mean=("latitude", "mean"),
+ lat_max=("latitude", "max"),
+ lat_min=("latitude", "min"),
+ lon_mean=("longitude", "mean"),
+ lon_max=("longitude", "max"),
+ lon_min=("longitude", "min"),
+ )
+ # Reset index so that the index fields are included as columns.
+ df_cohorts = df_cohorts.reset_index()
+
+ # Add cohort helper variables.
+ cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time)
+ cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time)
+ df_cohorts["period_start"] = cohort_period_start
+ df_cohorts["period_end"] = cohort_period_end
+ # Create a label that is similar to the cohort metadata,
+ # although this won't be perfect.
+ df_cohorts["label"] = df_cohorts.apply(
+ lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns"
+ )
+
+ # Apply minimum cohort size.
+ df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True)
+
+ return df_cohorts
+
+
+def _cohort_alt_allele_counts_melt(*, gt, indices, max_allele):
+ ac_alt_melt, an = _cohort_alt_allele_counts_melt_kernel(gt, indices, max_allele)
+ an_melt = np.repeat(an, max_allele, axis=0)
+ return ac_alt_melt, an_melt
+
+
+@numba.njit
+def _cohort_alt_allele_counts_melt_kernel(
+ gt, sample_indices, max_allele
+): # pragma: no cover
+ n_variants = gt.shape[0]
+ n_samples = sample_indices.shape[0]
+ ploidy = gt.shape[2]
+
+ ac_alt_melt = np.zeros(n_variants * max_allele, dtype=np.int64)
+ an = np.zeros(n_variants, dtype=np.int64)
+
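+ # Melted layout: row (i * max_allele) + (a - 1) holds the count of
+ # alternate allele a (1..max_allele) at variant i, while an[i] counts all
+ # non-missing allele calls (reference or alternate) at variant i.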
+ for i in range(n_variants):
+ out_i_offset = (i * max_allele) - 1
+ for j in range(n_samples):
+ sample_index = sample_indices[j]
+ for k in range(ploidy):
+ allele = gt[i, sample_index, k]
+ if allele > 0:
+ out_i = out_i_offset + allele
+ ac_alt_melt[out_i] += 1
+ an[i] += 1
+ elif allele == 0:
+ an[i] += 1
+
+ return ac_alt_melt, an
+
+
+def _add_frequency_ci(*, ds, ci_method):
+ from statsmodels.stats.proportion import proportion_confint # type: ignore
+
+ if ci_method is not None:
+ count = ds["event_count"].values
+ nobs = ds["event_nobs"].values
+ with np.errstate(divide="ignore", invalid="ignore"):
+ frq_ci_low, frq_ci_upp = proportion_confint(
+ count=count, nobs=nobs, method=ci_method
+ )
+ ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low
+ ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp
+
+
+def _map_snp_to_aa_change_frq_ds(ds):
+ # Keep only variables that make sense for amino acid substitutions.
+ keep_vars = [
+ "variant_contig",
+ "variant_position",
+ "variant_transcript",
+ "variant_effect",
+ "variant_impact",
+ "variant_aa_pos",
+ "variant_aa_change",
+ "variant_ref_allele",
+ "variant_ref_aa",
+ "variant_alt_aa",
+ "event_nobs",
+ ]
+
+ if ds.sizes["variants"] == 1:
+ # Keep everything as-is, no need for aggregation.
+ ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]]
+
+ else:
+ # Take the first value from all variants variables.
+ ds_out = ds[keep_vars].isel(variants=[0])
+
+ # Sum event count over variants.
+ count = ds["event_count"].values.sum(axis=0, keepdims=True)
+ ds_out["event_count"] = ("variants", "cohorts"), count
+
+ # Collapse alt allele.
+ alt_allele = "{" + ",".join(ds["variant_alt_allele"].values) + "}"
+ ds_out["variant_alt_allele"] = (
+ "variants",
+ np.array([alt_allele], dtype=object),
+ )
+
+ return ds_out
diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py
index 783e8017e..591ed07f3 100644
--- a/malariagen_data/anopheles.py
+++ b/malariagen_data/anopheles.py
@@ -3,7 +3,7 @@
from abc import abstractmethod
from bisect import bisect_left, bisect_right
from collections import Counter
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, Sequence
+from typing import Any, Dict, List, Mapping, Optional, Tuple, Sequence
import allel # type: ignore
import bokeh.layouts
@@ -20,14 +20,18 @@
from numpydoc_decorator import doc # type: ignore
from malariagen_data.anoph import tree_params
+from malariagen_data.anoph.snp_frq import (
+ AnophelesSnpFrequencyAnalysis,
+ _add_frequency_ci,
+ _build_cohorts_from_sample_grouping,
+ _prep_samples_for_cohort_grouping,
+)
-from . import veff
from .anoph import (
aim_params,
base_params,
dash_params,
diplotype_distance_params,
- frq_params,
fst_params,
g123_params,
gplt_params,
@@ -36,7 +40,6 @@
hapnet_params,
het_params,
ihs_params,
- map_params,
plotly_params,
xpehh_params,
)
@@ -49,7 +52,7 @@
from .anoph.hap_data import AnophelesHapData, hap_params
from .anoph.igv import AnophelesIgv
from .anoph.pca import AnophelesPca
-from .anoph.sample_metadata import AnophelesSampleMetadata
+from .anoph.sample_metadata import AnophelesSampleMetadata, locate_cohorts
from .anoph.snp_data import AnophelesSnpData
from .mjn import median_joining_network, mjn_graph
from .util import (
@@ -61,18 +64,15 @@
biallelic_diplotype_sqeuclidean,
check_types,
jackknife_ci,
- locate_region,
parse_multi_region,
parse_single_region,
pdist_abs_hamming,
plotly_discrete_legend,
region_str,
simple_xarray_concat,
+ pandas_apply,
)
-AA_CHANGE_QUERY = (
- "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']"
-)
DEFAULT_MAX_COVERAGE_VARIANCE = 0.2
@@ -98,6 +98,7 @@
# work around pycharm failing to recognise that doc() is callable
# noinspection PyCallingNonCallable
class AnophelesDataResource(
+ AnophelesSnpFrequencyAnalysis,
AnophelesPca,
AnophelesIgv,
AnophelesCnvData,
@@ -140,6 +141,7 @@ def __init__(
storage_options: Mapping, # used by fsspec via init_filesystem(url, **kwargs)
taxon_colors: Optional[Mapping[str, str]],
virtual_contigs: Optional[Mapping[str, Sequence[str]]],
+ gene_names: Optional[Mapping[str, str]],
):
super().__init__(
url=url,
@@ -169,12 +171,9 @@ def __init__(
tqdm_class=tqdm_class,
taxon_colors=taxon_colors,
virtual_contigs=virtual_contigs,
+ gene_names=gene_names,
)
- # set up caches
- # TODO review type annotations here, maybe can tighten
- self._cache_annotator = None
-
@property
@abstractmethod
def _fst_gwss_results_cache_name(self):
@@ -215,311 +214,6 @@ def _h1x_gwss_cache_name(self):
def _ihs_gwss_cache_name(self):
raise NotImplementedError("Must override _ihs_gwss_cache_name")
- @abstractmethod
- def _transcript_to_gene_name(self, transcript):
- # children may have different manual overrides.
- raise NotImplementedError("Must override _transcript_to_gene_name")
-
- @check_types
- @doc(
- summary="""
- Group samples by taxon, area (space) and period (time), then compute
- SNP allele frequencies.
- """,
- returns="""
- The resulting dataset contains data has dimensions "cohorts" and
- "variants". Variables prefixed with "cohort" are 1-dimensional
- arrays with data about the cohorts, such as the area, period, taxon
- and cohort size. Variables prefixed with "variant" are
- 1-dimensional arrays with data about the variants, such as the
- contig, position, reference and alternate alleles. Variables
- prefixed with "event" are 2-dimensional arrays with the allele
- counts and frequency calculations.
- """,
- )
- def snp_allele_frequencies_advanced(
- self,
- transcript: base_params.transcript,
- area_by: frq_params.area_by,
- period_by: frq_params.period_by,
- sample_sets: Optional[base_params.sample_sets] = None,
- sample_query: Optional[base_params.sample_query] = None,
- min_cohort_size: base_params.min_cohort_size = 10,
- drop_invariant: frq_params.drop_invariant = True,
- variant_query: Optional[frq_params.variant_query] = None,
- site_mask: Optional[base_params.site_mask] = None,
- nobs_mode: frq_params.nobs_mode = frq_params.nobs_mode_default,
- ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default,
- ) -> xr.Dataset:
- debug = self._log.debug
-
- debug("check parameters")
- self._check_param_min_cohort_size(min_cohort_size)
-
- debug("load sample metadata")
- df_samples = self.sample_metadata(
- sample_sets=sample_sets, sample_query=sample_query
- )
-
- debug("access SNP calls")
- ds_snps = self.snp_calls(
- region=transcript,
- sample_sets=sample_sets,
- sample_query=sample_query,
- site_mask=site_mask,
- )
-
- debug("access genotypes")
- gt = ds_snps["call_genotype"].data
-
- debug("prepare sample metadata for cohort grouping")
- df_samples = self._prep_samples_for_cohort_grouping(
- df_samples=df_samples,
- area_by=area_by,
- period_by=period_by,
- )
-
- debug("group samples to make cohorts")
- group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
-
- debug("build cohorts dataframe")
- df_cohorts = self._build_cohorts_from_sample_grouping(
- group_samples_by_cohort, min_cohort_size
- )
-
- debug("bring genotypes into memory")
- with self._dask_progress(desc="Load SNP genotypes"):
- gt = gt.compute()
-
- debug("set up variant variables")
- contigs = ds_snps.attrs["contigs"]
- variant_contig = np.repeat(
- [contigs[i] for i in ds_snps["variant_contig"].values], 3
- )
- variant_position = np.repeat(ds_snps["variant_position"].values, 3)
- alleles = ds_snps["variant_allele"].values
- variant_ref_allele = np.repeat(alleles[:, 0], 3)
- variant_alt_allele = alleles[:, 1:].flatten()
- variant_pass = dict()
- for site_mask in self.site_mask_ids:
- variant_pass[site_mask] = np.repeat(
- ds_snps[f"variant_filter_pass_{site_mask}"].values, 3
- )
-
- debug("setup main event variables")
- n_variants, n_cohorts = len(variant_position), len(df_cohorts)
- count = np.zeros((n_variants, n_cohorts), dtype=int)
- nobs = np.zeros((n_variants, n_cohorts), dtype=int)
-
- debug("build event count and nobs for each cohort")
- cohorts_iterator = self._progress(
- enumerate(df_cohorts.itertuples()),
- total=len(df_cohorts),
- desc="Compute SNP allele frequencies",
- )
- for cohort_index, cohort in cohorts_iterator:
- cohort_key = cohort.taxon, cohort.area, cohort.period
- sample_indices = group_samples_by_cohort.indices[cohort_key]
-
- cohort_ac, cohort_an = self._cohort_alt_allele_counts_melt(
- gt, sample_indices, max_allele=3
- )
- count[:, cohort_index] = cohort_ac
-
- if nobs_mode == "called":
- nobs[:, cohort_index] = cohort_an
- elif nobs_mode == "fixed":
- nobs[:, cohort_index] = cohort.size * 2
- else:
- raise ValueError(f"Bad nobs_mode: {nobs_mode!r}")
-
- debug("compute frequency")
- with np.errstate(divide="ignore", invalid="ignore"):
- # ignore division warnings
- frequency = count / nobs
-
- debug("compute maximum frequency over cohorts")
- with warnings.catch_warnings():
- # ignore "All-NaN slice encountered" warnings
- warnings.simplefilter("ignore", category=RuntimeWarning)
- max_af = np.nanmax(frequency, axis=1)
-
- debug("make dataframe of SNPs")
- df_variants_cols = {
- "contig": variant_contig,
- "position": variant_position,
- "ref_allele": variant_ref_allele.astype("U1"),
- "alt_allele": variant_alt_allele.astype("U1"),
- "max_af": max_af,
- }
- for site_mask in self.site_mask_ids:
- df_variants_cols[f"pass_{site_mask}"] = variant_pass[site_mask]
- df_variants = pd.DataFrame(df_variants_cols)
-
- debug("deal with SNP alleles not observed")
- if drop_invariant:
- loc_variant = max_af > 0
- df_variants = df_variants.loc[loc_variant].reset_index(drop=True)
- count = np.compress(loc_variant, count, axis=0)
- nobs = np.compress(loc_variant, nobs, axis=0)
- frequency = np.compress(loc_variant, frequency, axis=0)
-
- debug("set up variant effect annotator")
- ann = self._annotator()
-
- debug("add effects to the dataframe")
- ann.get_effects(
- transcript=transcript, variants=df_variants, progress=self._progress
- )
-
- debug("add variant labels")
- df_variants["label"] = self._pandas_apply(
- self._make_snp_label_effect,
- df_variants,
- columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
- )
-
- debug("build the output dataset")
- ds_out = xr.Dataset()
-
- debug("cohort variables")
- for coh_col in df_cohorts.columns:
- ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col]
-
- debug("variant variables")
- for snp_col in df_variants.columns:
- ds_out[f"variant_{snp_col}"] = "variants", df_variants[snp_col]
-
- debug("event variables")
- ds_out["event_count"] = ("variants", "cohorts"), count
- ds_out["event_nobs"] = ("variants", "cohorts"), nobs
- ds_out["event_frequency"] = ("variants", "cohorts"), frequency
-
- debug("apply variant query")
- if variant_query is not None:
- loc_variants = df_variants.eval(variant_query).values
- ds_out = ds_out.isel(variants=loc_variants)
-
- debug("add confidence intervals")
- self._add_frequency_ci(ds_out, ci_method)
-
- debug("tidy up display by sorting variables")
- sorted_vars: List[str] = sorted([str(k) for k in ds_out.keys()])
- ds_out = ds_out[sorted_vars]
-
- debug("add metadata")
- gene_name = self._transcript_to_gene_name(transcript)
- title = transcript
- if gene_name:
- title += f" ({gene_name})"
- title += " SNP frequencies"
- ds_out.attrs["title"] = title
-
- return ds_out
-
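For orientation, a minimal usage sketch of the method removed above, called via the public Ag3 API; the transcript ID, area column and sample set below are illustrative placeholders, not prescribed values.

import malariagen_data

ag3 = malariagen_data.Ag3()
ds = ag3.snp_allele_frequencies_advanced(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    area_by="admin1_iso",        # placeholder metadata column used as "area"
    period_by="year",
    sample_sets="3.0",           # placeholder sample set release
    min_cohort_size=10,
)
# The dataset has "cohorts" and "variants" dimensions, with cohort_*,
# variant_* and event_* variables as described in the docstring above.
print(ds.sizes)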
- # Start of @staticmethod
-
- @staticmethod
- def _locate_cohorts(*, cohorts, df_samples):
- # build cohort dictionary where key=cohort_id, value=loc_coh
- coh_dict = {}
-
- if isinstance(cohorts, dict):
- # user has supplied a custom dictionary mapping cohort identifiers
- # to pandas queries
-
- for coh, query in cohorts.items():
- # locate samples
- loc_coh = df_samples.eval(query).values
- coh_dict[coh] = loc_coh
-
- if isinstance(cohorts, str):
- # user has supplied one of the predefined cohort sets
-
- # fix the string to match columns
- if not cohorts.startswith("cohort_"):
- cohorts = "cohort_" + cohorts
-
- # check the given cohort set exists
- if cohorts not in df_samples.columns:
- raise ValueError(f"{cohorts!r} is not a known cohort set")
- cohort_labels = df_samples[cohorts].unique()
-
- # remove the nans and sort
- cohort_labels = sorted([c for c in cohort_labels if isinstance(c, str)])
- for coh in cohort_labels:
- loc_coh = df_samples[cohorts] == coh
- coh_dict[coh] = loc_coh.values
-
- return coh_dict
-
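To illustrate the two forms of the cohorts parameter handled above, a small sketch with hypothetical sample metadata; the column name, cohort labels and query strings are illustrative only.

import pandas as pd

df_samples = pd.DataFrame({
    "sample_id": ["s1", "s2", "s3"],
    "taxon": ["gambiae", "coluzzii", "gambiae"],
    "cohort_admin1_year": ["ML-2_gamb_2014", "ML-2_colu_2014", "ML-2_gamb_2014"],
})

# Dict form: cohort label -> pandas query string evaluated against the metadata.
cohorts_as_dict = {"gamb": "taxon == 'gambiae'", "colu": "taxon == 'coluzzii'"}

# String form: a predefined cohort column, given with or without the
# "cohort_" prefix; e.g. "admin1_year" resolves to "cohort_admin1_year".
cohorts_as_str = "admin1_year"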
- @staticmethod
- def _make_sample_period_month(row):
- year = row.year
- month = row.month
- if year > 0 and month > 0:
- return pd.Period(freq="M", year=year, month=month)
- else:
- return pd.NaT
-
- @staticmethod
- def _make_sample_period_quarter(row):
- year = row.year
- month = row.month
- if year > 0 and month > 0:
- return pd.Period(freq="Q", year=year, month=month)
- else:
- return pd.NaT
-
- @staticmethod
- def _make_sample_period_year(row):
- year = row.year
- if year > 0:
- return pd.Period(freq="A", year=year)
- else:
- return pd.NaT
-
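The three helpers above bin samples into pandas Periods at different resolutions; for example (values are illustrative):

import pandas as pd

# Samples with missing dates are encoded with year/month <= 0 and map to pd.NaT.
print(pd.Period(freq="M", year=2018, month=7))  # 2018-07
print(pd.Period(freq="Q", year=2018, month=7))  # 2018Q3
print(pd.Period(freq="A", year=2018))           # 2018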
- @staticmethod
- @numba.njit
- def _cohort_alt_allele_counts_melt_kernel(gt, indices, max_allele):
- n_variants = gt.shape[0]
- n_indices = indices.shape[0]
- ploidy = gt.shape[2]
-
- ac_alt_melt = np.zeros(n_variants * max_allele, dtype=np.int64)
- an = np.zeros(n_variants, dtype=np.int64)
-
- for i in range(n_variants):
- out_i_offset = (i * max_allele) - 1
- for j in range(n_indices):
- ix = indices[j]
- for k in range(ploidy):
- allele = gt[i, ix, k]
- if allele > 0:
- out_i = out_i_offset + allele
- ac_alt_melt[out_i] += 1
- an[i] += 1
- elif allele == 0:
- an[i] += 1
-
- return ac_alt_melt, an
-
- @staticmethod
- def _make_snp_label(contig, position, ref_allele, alt_allele):
- return f"{contig}:{position:,} {ref_allele}>{alt_allele}"
-
- @staticmethod
- def _make_snp_label_effect(contig, position, ref_allele, alt_allele, aa_change):
- label = f"{contig}:{position:,} {ref_allele}>{alt_allele}"
- if isinstance(aa_change, str):
- label += f" ({aa_change})"
- return label
-
- @staticmethod
- def _make_snp_label_aa(aa_change, contig, position, ref_allele, alt_allele):
- label = f"{aa_change} ({contig}:{position:,} {ref_allele}>{alt_allele})"
- return label
-
@staticmethod
def _make_gene_cnv_label(gene_id, gene_name, cnv_type):
label = gene_id
@@ -528,110 +222,6 @@ def _make_gene_cnv_label(gene_id, gene_name, cnv_type):
label += f" {cnv_type}"
return label
- @staticmethod
- def _map_snp_to_aa_change_frq_ds(ds):
- # keep only variables that make sense for amino acid substitutions
- keep_vars = [
- "variant_contig",
- "variant_position",
- "variant_transcript",
- "variant_effect",
- "variant_impact",
- "variant_aa_pos",
- "variant_aa_change",
- "variant_ref_allele",
- "variant_ref_aa",
- "variant_alt_aa",
- "event_nobs",
- ]
-
- if ds.sizes["variants"] == 1:
- # keep everything as-is, no need for aggregation
- ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]]
-
- else:
- # take the first value from all variants variables
- ds_out = ds[keep_vars].isel(variants=[0])
-
- # sum event count over variants
- count = ds["event_count"].values.sum(axis=0, keepdims=True)
- ds_out["event_count"] = ("variants", "cohorts"), count
-
- # collapse alt allele
- alt_allele = "{" + ",".join(ds["variant_alt_allele"].values) + "}"
- ds_out["variant_alt_allele"] = (
- "variants",
- np.array([alt_allele], dtype=object),
- )
-
- return ds_out
-
- @staticmethod
- def _add_frequency_ci(ds, ci_method):
- from statsmodels.stats.proportion import proportion_confint # type: ignore
-
- if ci_method is not None:
- count = ds["event_count"].values
- nobs = ds["event_nobs"].values
- with np.errstate(divide="ignore", invalid="ignore"):
- frq_ci_low, frq_ci_upp = proportion_confint(
- count=count, nobs=nobs, method=ci_method
- )
- ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low
- ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp
-
- @staticmethod
- def _build_cohorts_from_sample_grouping(group_samples_by_cohort, min_cohort_size):
- # build cohorts dataframe
- df_cohorts = group_samples_by_cohort.agg(
- size=("sample_id", len),
- lat_mean=("latitude", "mean"),
- lat_max=("latitude", "mean"),
- lat_min=("latitude", "mean"),
- lon_mean=("longitude", "mean"),
- lon_max=("longitude", "mean"),
- lon_min=("longitude", "mean"),
- )
- # reset index so that the index fields are included as columns
- df_cohorts = df_cohorts.reset_index()
-
- # add cohort helper variables
- cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time)
- cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time)
- df_cohorts["period_start"] = cohort_period_start
- df_cohorts["period_end"] = cohort_period_end
- # create a label that is similar to the cohort metadata,
- # although this won't be perfect
- df_cohorts["label"] = df_cohorts.apply(
- lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns"
- )
-
- # apply minimum cohort size
- df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(
- drop=True
- )
-
- return df_cohorts
-
- @staticmethod
- def _check_param_min_cohort_size(min_cohort_size):
- if not isinstance(min_cohort_size, int):
- raise TypeError(
- f"Type of parameter min_cohort_size must be int; found {type(min_cohort_size)}."
- )
- if min_cohort_size < 1:
- raise ValueError(
- f"Value of parameter min_cohort_size must be at least 1; found {min_cohort_size}."
- )
-
- @staticmethod
- def _pandas_apply(f, df, columns):
- """Optimised alternative to pandas apply."""
- df = df.reset_index(drop=True)
- iterator = zip(*[df[c].values for c in columns])
- ret = pd.Series((f(*vals) for vals in iterator))
- return ret
-
@staticmethod
def _roh_hmm_predict(
*,
@@ -701,136 +291,6 @@ def _roh_hmm_predict(
]
]
- def _snp_df(self, *, transcript: str) -> Tuple[Region, pd.DataFrame]:
- """Set up a dataframe with SNP site and filter columns."""
- debug = self._log.debug
-
- debug("get feature direct from genome_features")
- gs = self.genome_features()
-
- with self._spinner(desc="Prepare SNP data"):
- feature = gs[gs["ID"] == transcript].squeeze()
- if feature.empty:
- raise ValueError(
- f"No genome feature ID found matching transcript {transcript}"
- )
- contig = feature.contig
- region = Region(contig, feature.start, feature.end)
-
- debug("grab pos, ref and alt for chrom arm from snp_sites")
- pos = self.snp_sites(region=contig, field="POS")
- ref = self.snp_sites(region=contig, field="REF")
- alt = self.snp_sites(region=contig, field="ALT")
- loc_feature = locate_region(region, pos)
- pos = pos[loc_feature].compute()
- ref = ref[loc_feature].compute()
- alt = alt[loc_feature].compute()
-
- debug("access site filters")
- filter_pass = dict()
- masks = self.site_mask_ids
- for m in masks:
- x = self.site_filters(region=contig, mask=m)
- x = x[loc_feature].compute()
- filter_pass[m] = x
-
- debug("set up columns with contig, pos, ref, alt columns")
- cols = {
- "contig": contig,
- "position": np.repeat(pos, 3),
- "ref_allele": np.repeat(ref.astype("U1"), 3),
- "alt_allele": alt.astype("U1").flatten(),
- }
-
- debug("add mask columns")
- for m in masks:
- x = filter_pass[m]
- cols[f"pass_{m}"] = np.repeat(x, 3)
-
- debug("construct dataframe")
- df_snps = pd.DataFrame(cols)
-
- return region, df_snps
-
- def _annotator(self):
- """Set up variant effect annotator."""
- if self._cache_annotator is None:
- self._cache_annotator = veff.Annotator(
- genome=self.open_genome(), genome_features=self.genome_features()
- )
- return self._cache_annotator
-
- @check_types
- @doc(
- summary="Compute variant effects for a gene transcript.",
- returns="""
- A dataframe of all possible SNP variants and their effects, one row
- per variant.
- """,
- )
- def snp_effects(
- self,
- transcript: base_params.transcript,
- site_mask: Optional[base_params.site_mask] = None,
- ) -> pd.DataFrame:
- debug = self._log.debug
-
- debug("setup initial dataframe of SNPs")
- _, df_snps = self._snp_df(transcript=transcript)
-
- debug("setup variant effect annotator")
- ann = self._annotator()
-
- debug("apply mask if requested")
- if site_mask is not None:
- loc_sites = df_snps[f"pass_{site_mask}"]
- df_snps = df_snps.loc[loc_sites]
-
- debug("reset index after filtering")
- df_snps.reset_index(inplace=True, drop=True)
-
- debug("add effects to the dataframe")
- ann.get_effects(transcript=transcript, variants=df_snps)
-
- return df_snps
-
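A minimal usage sketch of snp_effects() via the public API; the transcript ID and site mask are placeholders.

import malariagen_data

ag3 = malariagen_data.Ag3()
df_effects = ag3.snp_effects(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    site_mask="gamb_colu_arab",  # placeholder site mask
)
# One row per possible SNP allele, with effect, impact and aa_change columns.
print(df_effects[["position", "ref_allele", "alt_allele", "effect", "impact"]].head())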
- def _cohort_alt_allele_counts_melt(self, gt, indices, max_allele):
- ac_alt_melt, an = self._cohort_alt_allele_counts_melt_kernel(
- gt, indices, max_allele
- )
- an_melt = np.repeat(an, max_allele, axis=0)
- return ac_alt_melt, an_melt
-
- def _prep_samples_for_cohort_grouping(self, *, df_samples, area_by, period_by):
- # take a copy, as we will modify the dataframe
- df_samples = df_samples.copy()
-
- # fix intermediate taxon values - we only want to build cohorts with clean
- # taxon calls, so we set intermediate values to None
- loc_intermediate_taxon = (
- df_samples["taxon"].str.startswith("intermediate").fillna(False)
- )
- df_samples.loc[loc_intermediate_taxon, "taxon"] = None
-
- # add period column
- if period_by == "year":
- make_period = self._make_sample_period_year
- elif period_by == "quarter":
- make_period = self._make_sample_period_quarter
- elif period_by == "month":
- make_period = self._make_sample_period_month
- else:
- raise ValueError(
- f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}."
- )
- sample_period = df_samples.apply(make_period, axis="columns")
- df_samples["period"] = sample_period
-
- # add area column for consistent output
- df_samples["area"] = df_samples[area_by]
-
- return df_samples
-
def _plot_heterozygosity_track(
self,
*,
@@ -1356,352 +816,6 @@ def plot_roh(
else:
return fig_all
- @check_types
- @doc(
- summary="""
- Compute SNP allele frequencies for a gene transcript.
- """,
- returns="""
- A dataframe of SNP allele frequencies, one row per variant allele.
- """,
- notes="""
- Cohorts with fewer samples than `min_cohort_size` will be excluded from
- the output data frame.
- """,
- )
- def snp_allele_frequencies(
- self,
- transcript: base_params.transcript,
- cohorts: base_params.cohorts,
- sample_query: Optional[base_params.sample_query] = None,
- min_cohort_size: base_params.min_cohort_size = 10,
- site_mask: Optional[base_params.site_mask] = None,
- sample_sets: Optional[base_params.sample_sets] = None,
- drop_invariant: frq_params.drop_invariant = True,
- effects: frq_params.effects = True,
- ) -> pd.DataFrame:
- debug = self._log.debug
-
- debug("check parameters")
- self._check_param_min_cohort_size(min_cohort_size)
-
- debug("access sample metadata")
- df_samples = self.sample_metadata(
- sample_sets=sample_sets, sample_query=sample_query
- )
-
- debug("setup initial dataframe of SNPs")
- region, df_snps = self._snp_df(transcript=transcript)
-
- debug("get genotypes")
- gt = self.snp_genotypes(
- region=region,
- sample_sets=sample_sets,
- sample_query=sample_query,
- field="GT",
- )
-
- debug("slice to feature location")
- with self._dask_progress(desc="Load SNP genotypes"):
- gt = gt.compute()
-
- debug("build coh dict")
- coh_dict = self._locate_cohorts(cohorts=cohorts, df_samples=df_samples)
-
- debug("count alleles")
- freq_cols = dict()
- cohorts_iterator = self._progress(
- coh_dict.items(), desc="Compute allele frequencies"
- )
- for coh, loc_coh in cohorts_iterator:
- n_samples = np.count_nonzero(loc_coh)
- debug(f"{coh}, {n_samples} samples")
- if n_samples >= min_cohort_size:
- gt_coh = np.compress(loc_coh, gt, axis=1)
- ac_coh = allel.GenotypeArray(gt_coh).count_alleles(max_allele=3)
- af_coh = ac_coh.to_frequencies()
- freq_cols["frq_" + coh] = af_coh[:, 1:].flatten()
-
- debug("build a dataframe with the frequency columns")
- df_freqs = pd.DataFrame(freq_cols)
-
- debug("compute max_af")
- df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)})
-
- debug("build the final dataframe")
- df_snps.reset_index(drop=True, inplace=True)
- df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1)
-
- debug("apply site mask if requested")
- if site_mask is not None:
- loc_sites = df_snps[f"pass_{site_mask}"]
- df_snps = df_snps.loc[loc_sites]
-
- debug("drop invariants")
- if drop_invariant:
- loc_variant = df_snps["max_af"] > 0
- df_snps = df_snps.loc[loc_variant]
-
- debug("reset index after filtering")
- df_snps.reset_index(inplace=True, drop=True)
-
- if effects:
- debug("add effect annotations")
- ann = self._annotator()
- ann.get_effects(
- transcript=transcript, variants=df_snps, progress=self._progress
- )
-
- debug("add label")
- df_snps["label"] = self._pandas_apply(
- self._make_snp_label_effect,
- df_snps,
- columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"],
- )
-
- debug("set index")
- df_snps.set_index(
- ["contig", "position", "ref_allele", "alt_allele", "aa_change"],
- inplace=True,
- )
-
- else:
- debug("add label")
- df_snps["label"] = self._pandas_apply(
- self._make_snp_label,
- df_snps,
- columns=["contig", "position", "ref_allele", "alt_allele"],
- )
-
- debug("set index")
- df_snps.set_index(
- ["contig", "position", "ref_allele", "alt_allele"],
- inplace=True,
- )
-
- debug("add dataframe metadata")
- gene_name = self._transcript_to_gene_name(transcript)
- title = transcript
- if gene_name:
- title += f" ({gene_name})"
- title += " SNP frequencies"
- df_snps.attrs["title"] = title
-
- return df_snps
-
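A minimal usage sketch of snp_allele_frequencies(); the transcript, cohorts and sample sets are placeholders.

import malariagen_data

ag3 = malariagen_data.Ag3()
df_frq = ag3.snp_allele_frequencies(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    cohorts="admin1_year",       # placeholder predefined cohort column
    sample_sets="3.0",           # placeholder sample set release
    min_cohort_size=10,
)
# One frequency column per retained cohort, prefixed "frq_", plus "max_af".
print(df_frq.filter(like="frq_").head())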
- @check_types
- @doc(
- summary="""
- Compute amino acid substitution frequencies for a gene transcript.
- """,
- returns="""
- A dataframe of amino acid allele frequencies, one row per
- substitution.
- """,
- notes="""
- Cohorts with fewer samples than `min_cohort_size` will be excluded from
- the output data frame.
- """,
- )
- def aa_allele_frequencies(
- self,
- transcript: base_params.transcript,
- cohorts: base_params.cohorts,
- sample_query: Optional[base_params.sample_query] = None,
- min_cohort_size: Optional[base_params.min_cohort_size] = 10,
- site_mask: Optional[base_params.site_mask] = None,
- sample_sets: Optional[base_params.sample_sets] = None,
- drop_invariant: frq_params.drop_invariant = True,
- ) -> pd.DataFrame:
- debug = self._log.debug
-
- df_snps = self.snp_allele_frequencies(
- transcript=transcript,
- cohorts=cohorts,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- site_mask=site_mask,
- sample_sets=sample_sets,
- drop_invariant=drop_invariant,
- effects=True,
- )
- df_snps.reset_index(inplace=True)
-
- # we just want aa change
- df_ns_snps = df_snps.query(AA_CHANGE_QUERY).copy()
-
- # N.B., we need to worry about the possibility of the
- # same aa change due to SNPs at different positions. We cannot
- # sum frequencies of SNPs at different genomic positions. This
- # is why we group by position and aa_change, not just aa_change.
-
- debug("group and sum to collapse multi variant allele changes")
- freq_cols = [col for col in df_ns_snps if col.startswith("frq")]
- agg: Dict[str, Union[Callable, str]] = {c: np.nansum for c in freq_cols}
- keep_cols = (
- "contig",
- "transcript",
- "aa_pos",
- "ref_allele",
- "ref_aa",
- "alt_aa",
- "effect",
- "impact",
- )
- for c in keep_cols:
- agg[c] = "first"
- agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v
- df_aaf = df_ns_snps.groupby(["position", "aa_change"]).agg(agg).reset_index()
-
- debug("compute new max_af")
- df_aaf["max_af"] = df_aaf[freq_cols].max(axis=1)
-
- debug("add label")
- df_aaf["label"] = self._pandas_apply(
- self._make_snp_label_aa,
- df_aaf,
- columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"],
- )
-
- debug("sort by genomic position")
- df_aaf = df_aaf.sort_values(["position", "aa_change"])
-
- debug("set index")
- df_aaf.set_index(["aa_change", "contig", "position"], inplace=True)
-
- debug("add metadata")
- gene_name = self._transcript_to_gene_name(transcript)
- title = transcript
- if gene_name:
- title += f" ({gene_name})"
- title += " SNP frequencies"
- df_aaf.attrs["title"] = title
-
- return df_aaf
-
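To make the grouping rationale above concrete, a toy example with hypothetical positions, amino acid changes and frequencies: alleles at the same genomic position producing the same amino acid change are collapsed by summing, while the same-looking change at a different position is kept as a separate row.

import numpy as np
import pandas as pd

df = pd.DataFrame({
    "position": [101, 101, 103],
    "aa_change": ["L34F", "L34F", "L34F"],
    "frq_coh1": [0.10, 0.05, 0.20],
})
df_aaf = (
    df.groupby(["position", "aa_change"])
    .agg({"frq_coh1": np.nansum})
    .reset_index()
)
# Two rows remain: (101, "L34F") with frequency 0.15, and (103, "L34F") with 0.20.
print(df_aaf)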
- @check_types
- @doc(
- summary="""
- Group samples by taxon, area (space) and period (time), then compute
- amino acid change allele frequencies.
- """,
- returns="""
- The resulting dataset has dimensions "cohorts" and
- "variants". Variables prefixed with "cohort" are 1-dimensional
- arrays with data about the cohorts, such as the area, period, taxon
- and cohort size. Variables prefixed with "variant" are
- 1-dimensional arrays with data about the variants, such as the
- contig, position, reference and alternate alleles. Variables
- prefixed with "event" are 2-dimensional arrays with the allele
- counts and frequency calculations.
- """,
- )
- def aa_allele_frequencies_advanced(
- self,
- transcript: base_params.transcript,
- area_by: frq_params.area_by,
- period_by: frq_params.period_by,
- sample_sets: Optional[base_params.sample_sets] = None,
- sample_query: Optional[base_params.sample_query] = None,
- min_cohort_size: base_params.min_cohort_size = 10,
- variant_query: Optional[frq_params.variant_query] = None,
- site_mask: Optional[base_params.site_mask] = None,
- nobs_mode: frq_params.nobs_mode = "called",
- ci_method: Optional[frq_params.ci_method] = "wilson",
- ) -> xr.Dataset:
- debug = self._log.debug
-
- debug("begin by computing SNP allele frequencies")
- ds_snp_frq = self.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- drop_invariant=True, # always drop invariant for aa frequencies
- variant_query=AA_CHANGE_QUERY, # we'll also apply a variant query later
- site_mask=site_mask,
- nobs_mode=nobs_mode,
- ci_method=None, # we will recompute confidence intervals later
- )
-
- # N.B., we need to worry about the possibility of the
- # same aa change due to SNPs at different positions. We cannot
- # sum frequencies of SNPs at different genomic positions. This
- # is why we group by position and aa_change, not just aa_change.
-
- # add in a special grouping column to work around the fact that xarray currently
- # doesn't support grouping by multiple variables in the same dimension
- df_grouper = ds_snp_frq[
- ["variant_position", "variant_aa_change"]
- ].to_dataframe()
- grouper_var = df_grouper.apply(
- lambda row: "_".join([str(v) for v in row]), axis="columns"
- )
- ds_snp_frq["variant_position_aa_change"] = "variants", grouper_var
-
- debug("group by position and amino acid change")
- group_by_aa_change = ds_snp_frq.groupby("variant_position_aa_change")
-
- debug("apply aggregation")
- ds_aa_frq = group_by_aa_change.map(self._map_snp_to_aa_change_frq_ds)
-
- debug("add back in cohort variables, unaffected by aggregation")
- cohort_vars = [v for v in ds_snp_frq if v.startswith("cohort_")]
- for v in cohort_vars:
- ds_aa_frq[v] = ds_snp_frq[v]
-
- debug("sort by genomic position")
- ds_aa_frq = ds_aa_frq.sortby(["variant_position", "variant_aa_change"])
-
- debug("recompute frequency")
- count = ds_aa_frq["event_count"].values
- nobs = ds_aa_frq["event_nobs"].values
- with np.errstate(divide="ignore", invalid="ignore"):
- frequency = count / nobs # ignore division warnings
- ds_aa_frq["event_frequency"] = ("variants", "cohorts"), frequency
-
- debug("recompute max frequency over cohorts")
- with warnings.catch_warnings():
- # ignore "All-NaN slice encountered" warnings
- warnings.simplefilter("ignore", category=RuntimeWarning)
- max_af = np.nanmax(ds_aa_frq["event_frequency"].values, axis=1)
- ds_aa_frq["variant_max_af"] = "variants", max_af
-
- debug("set up variant dataframe, useful intermediate")
- variant_cols = [v for v in ds_aa_frq if v.startswith("variant_")]
- df_variants = ds_aa_frq[variant_cols].to_dataframe()
- df_variants.columns = [c.split("variant_")[1] for c in df_variants.columns]
-
- debug("assign new variant label")
- label = self._pandas_apply(
- self._make_snp_label_aa,
- df_variants,
- columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"],
- )
- ds_aa_frq["variant_label"] = "variants", label
-
- debug("apply variant query if given")
- if variant_query is not None:
- loc_variants = df_variants.eval(variant_query).values
- ds_aa_frq = ds_aa_frq.isel(variants=loc_variants)
-
- debug("compute new confidence intervals")
- self._add_frequency_ci(ds_aa_frq, ci_method)
-
- debug("tidy up display by sorting variables")
- ds_aa_frq = ds_aa_frq[sorted(ds_aa_frq)]
-
- gene_name = self._transcript_to_gene_name(transcript)
- title = transcript
- if gene_name:
- title += f" ({gene_name})"
- title += " SNP frequencies"
- ds_aa_frq.attrs["title"] = title
-
- return ds_aa_frq
-
def gene_cnv(
self,
region: base_params.regions,
@@ -1889,7 +1003,6 @@ def gene_cnv_frequencies(
debug = self._log.debug
debug("check and normalise parameters")
- self._check_param_min_cohort_size(min_cohort_size)
regions: List[Region] = parse_multi_region(self, region)
del region
@@ -1996,7 +1109,7 @@ def _gene_cnv_frequencies(
is_called = cn >= 0
debug("set up cohort dict")
- coh_dict = self._locate_cohorts(cohorts=cohorts, df_samples=df_samples)
+ coh_dict = locate_cohorts(cohorts=cohorts, data=df_samples)
debug("compute cohort frequencies")
freq_cols = dict()
@@ -2046,7 +1159,7 @@ def _gene_cnv_frequencies(
df.reset_index(drop=True, inplace=True)
debug("add label")
- df["label"] = self._pandas_apply(
+ df["label"] = pandas_apply(
self._make_gene_cnv_label, df, columns=["gene_id", "gene_name", "cnv_type"]
)
@@ -2123,8 +1236,6 @@ def gene_cnv_frequencies_advanced(
"""
- self._check_param_min_cohort_size(min_cohort_size)
-
regions: List[Region] = parse_multi_region(self, region)
del region
@@ -2187,7 +1298,7 @@ def _gene_cnv_frequencies_advanced(
df_samples = df_samples.set_index("sample_id").loc[sample_id].reset_index()
debug("prepare sample metadata for cohort grouping")
- df_samples = self._prep_samples_for_cohort_grouping(
+ df_samples = _prep_samples_for_cohort_grouping(
df_samples=df_samples,
area_by=area_by,
period_by=period_by,
@@ -2197,8 +1308,9 @@ def _gene_cnv_frequencies_advanced(
group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"])
debug("build cohorts dataframe")
- df_cohorts = self._build_cohorts_from_sample_grouping(
- group_samples_by_cohort, min_cohort_size
+ df_cohorts = _build_cohorts_from_sample_grouping(
+ group_samples_by_cohort=group_samples_by_cohort,
+ min_cohort_size=min_cohort_size,
)
debug("figure out expected copy number")
@@ -2267,7 +1379,7 @@ def _gene_cnv_frequencies_advanced(
)
debug("add variant label")
- df_variants["label"] = self._pandas_apply(
+ df_variants["label"] = pandas_apply(
self._make_gene_cnv_label,
df_variants,
columns=["gene_id", "gene_name", "cnv_type"],
@@ -2301,7 +1413,7 @@ def _gene_cnv_frequencies_advanced(
ds_out = ds_out.isel(variants=loc_variants)
debug("add confidence intervals")
- self._add_frequency_ci(ds_out, ci_method)
+ _add_frequency_ci(ds=ds_out, ci_method=ci_method)
debug("tidy up display by sorting variables")
ds_out = ds_out[sorted(ds_out)]
@@ -3021,441 +2133,6 @@ def _fst_gwss(
return results
- @check_types
- @doc(
- summary="""
- Plot a heatmap from a pandas DataFrame of frequencies, e.g., output
- from `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
- """,
- parameters=dict(
- df="""
- A DataFrame of frequencies, e.g., output from
- `snp_allele_frequencies()` or `gene_cnv_frequencies()`.
- """,
- index="""
- One or more column headers that are present in the input dataframe.
- These become the heatmap y-axis row labels. The column(s) must
- produce a unique index.
- """,
- max_len="""
- Displaying large styled dataframes may cause ipython notebooks to
- crash. If the input dataframe is larger than this value, an error
- will be raised.
- """,
- col_width="""
- Plot width per column in pixels (px).
- """,
- row_height="""
- Plot height per row in pixels (px).
- """,
- kwargs="""
- Passed through to `px.imshow()`.
- """,
- ),
- notes="""
- It's recommended to filter the input DataFrame to just rows of interest,
- i.e., fewer rows than `max_len`.
- """,
- )
- def plot_frequencies_heatmap(
- self,
- df: pd.DataFrame,
- index: Union[str, List[str]] = "label",
- max_len: Optional[int] = 100,
- col_width: int = 40,
- row_height: int = 20,
- x_label: plotly_params.x_label = "Cohorts",
- y_label: plotly_params.y_label = "Variants",
- colorbar: plotly_params.colorbar = True,
- width: plotly_params.width = None,
- height: plotly_params.height = None,
- text_auto: plotly_params.text_auto = ".0%",
- aspect: plotly_params.aspect = "auto",
- color_continuous_scale: plotly_params.color_continuous_scale = "Reds",
- title: plotly_params.title = True,
- show: plotly_params.show = True,
- renderer: plotly_params.renderer = None,
- **kwargs,
- ) -> plotly_params.figure:
- debug = self._log.debug
-
- debug("check len of input")
- if max_len and len(df) > max_len:
- raise ValueError(f"Input DataFrame is longer than {max_len}")
-
- debug("handle title")
- if title is True:
- title = df.attrs.get("title", None)
-
- debug("indexing")
- if index is None:
- index = list(df.index.names)
- df = df.reset_index().copy()
- if isinstance(index, list):
- index_col = (
- df[index]
- .astype(str)
- .apply(
- lambda row: ", ".join([o for o in row if o is not None]),
- axis="columns",
- )
- )
- elif isinstance(index, str):
- index_col = df[index].astype(str)
- else:
- raise TypeError("wrong type for index parameter, expected list or str")
-
- debug("check that index is unique")
- if not index_col.is_unique:
- raise ValueError(f"{index} does not produce a unique index")
-
- debug("drop and re-order columns")
- frq_cols = [col for col in df.columns if col.startswith("frq_")]
-
- debug("keep only freq cols")
- heatmap_df = df[frq_cols].copy()
-
- debug("set index")
- heatmap_df.set_index(index_col, inplace=True)
-
- debug("clean column names")
- heatmap_df.columns = [col[len("frq_"):] for col in heatmap_df.columns]
-
- debug("deal with width and height")
- if width is None:
- width = 400 + col_width * len(heatmap_df.columns)
- if colorbar:
- width += 40
- if height is None:
- height = 200 + row_height * len(heatmap_df)
- if title is not None:
- height += 40
-
- debug("plotly heatmap styling")
- fig = px.imshow(
- img=heatmap_df,
- zmin=0,
- zmax=1,
- width=width,
- height=height,
- text_auto=text_auto,
- aspect=aspect,
- color_continuous_scale=color_continuous_scale,
- title=title,
- **kwargs,
- )
-
- fig.update_xaxes(side="bottom", tickangle=30)
- if x_label is not None:
- fig.update_xaxes(title=x_label)
- if y_label is not None:
- fig.update_yaxes(title=y_label)
- fig.update_layout(
- coloraxis_colorbar=dict(
- title="Frequency",
- tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0],
- ticktext=["0%", "20%", "40%", "60%", "80%", "100%"],
- )
- )
- if not colorbar:
- fig.update(layout_coloraxis_showscale=False)
-
- if show: # pragma: no cover
- fig.show(renderer=renderer)
- return None
- else:
- return fig
-
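A minimal usage sketch, chaining snp_allele_frequencies() into the heatmap; all argument values are placeholders, and the query simply illustrates filtering to fewer rows than max_len.

import malariagen_data

ag3 = malariagen_data.Ag3()
df_frq = ag3.snp_allele_frequencies(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    cohorts="admin1_year",       # placeholder predefined cohort column
    sample_sets="3.0",           # placeholder sample set release
)
# Keep only rows of interest, well under max_len, before plotting.
df_plot = df_frq.query("effect == 'NON_SYNONYMOUS_CODING' and max_af > 0.05")
fig = ag3.plot_frequencies_heatmap(df_plot, index="label", show=False)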
- @check_types
- @doc(
- summary="Create a time series plot of variant frequencies using plotly.",
- parameters=dict(
- ds="""
- A dataset of variant frequencies, such as returned by
- `snp_allele_frequencies_advanced()`,
- `aa_allele_frequencies_advanced()` or
- `gene_cnv_frequencies_advanced()`.
- """,
- kwargs="Passed through to `px.line()`.",
- ),
- returns="""
- A plotly figure containing line graphs. The resulting figure will
- have one panel per cohort, grouped into columns by taxon, and
- grouped into rows by area. Markers and lines show frequencies of
- variants.
- """,
- )
- def plot_frequencies_time_series(
- self,
- ds: xr.Dataset,
- height: plotly_params.height = None,
- width: plotly_params.width = None,
- title: plotly_params.title = True,
- legend_sizing: plotly_params.legend_sizing = "constant",
- show: plotly_params.show = True,
- renderer: plotly_params.renderer = None,
- **kwargs,
- ) -> plotly_params.figure:
- debug = self._log.debug
-
- debug("handle title")
- if title is True:
- title = ds.attrs.get("title", None)
-
- debug("extract cohorts into a dataframe")
- cohort_vars = [v for v in ds if str(v).startswith("cohort_")]
- df_cohorts = ds[cohort_vars].to_dataframe()
- df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore
-
- debug("extract variant labels")
- variant_labels = ds["variant_label"].values
-
- debug("build a long-form dataframe from the dataset")
- dfs = []
- for cohort_index, cohort in enumerate(df_cohorts.itertuples()):
- ds_cohort = ds.isel(cohorts=cohort_index)
- df = pd.DataFrame(
- {
- "taxon": cohort.taxon,
- "area": cohort.area,
- "date": cohort.period_start,
- "period": str(
- cohort.period
- ), # use string representation for hover label
- "sample_size": cohort.size,
- "variant": variant_labels,
- "count": ds_cohort["event_count"].values,
- "nobs": ds_cohort["event_nobs"].values,
- "frequency": ds_cohort["event_frequency"].values,
- "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values,
- "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values,
- }
- )
- dfs.append(df)
- df_events = pd.concat(dfs, axis=0).reset_index(drop=True)
-
- debug("remove events with no observations")
- df_events = df_events.query("nobs > 0")
-
- debug("calculate error bars")
- frq = df_events["frequency"]
- frq_ci_low = df_events["frequency_ci_low"]
- frq_ci_upp = df_events["frequency_ci_upp"]
- df_events["frequency_error"] = frq_ci_upp - frq
- df_events["frequency_error_minus"] = frq - frq_ci_low
-
- debug("make a plot")
- fig = px.line(
- df_events,
- facet_col="taxon",
- facet_row="area",
- x="date",
- y="frequency",
- error_y="frequency_error",
- error_y_minus="frequency_error_minus",
- color="variant",
- markers=True,
- hover_name="variant",
- hover_data={
- "frequency": ":.0%",
- "period": True,
- "area": True,
- "taxon": True,
- "sample_size": True,
- "date": False,
- "variant": False,
- },
- height=height,
- width=width,
- title=title,
- labels={
- "date": "Date",
- "frequency": "Frequency",
- "variant": "Variant",
- "taxon": "Taxon",
- "area": "Area",
- "period": "Period",
- "sample_size": "Sample size",
- },
- **kwargs,
- )
-
- debug("tidy plot")
- fig.update_layout(
- yaxis_range=[-0.05, 1.05],
- legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
- )
-
- if show: # pragma: no cover
- fig.show(renderer=renderer)
- return None
- else:
- return fig
-
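A minimal usage sketch, consistent with the notebook cells updated later in this change; argument values are placeholders.

import malariagen_data

ag3 = malariagen_data.Ag3()
ds = ag3.snp_allele_frequencies_advanced(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    area_by="admin1_iso",        # placeholder area column
    period_by="year",
    sample_sets="3.0",           # placeholder sample set release
)
fig = ag3.plot_frequencies_time_series(ds, height=500, width=1000, show=False)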
- @check_types
- @doc(
- summary="""
- Plot markers on a map showing variant frequencies for cohorts grouped
- by area (space), period (time) and taxon.
- """,
- parameters=dict(
- m="The map on which to add the markers.",
- variant="Index or label of variant to plot.",
- taxon="Taxon to show markers for.",
- period="Time period to show markers for.",
- clear="""
- If True, clear all layers (except the base layer) from the map
- before adding new markers.
- """,
- ),
- )
- def plot_frequencies_map_markers(
- self,
- m,
- ds: frq_params.ds_frequencies_advanced,
- variant: Union[int, str],
- taxon: str,
- period: pd.Period,
- clear: bool = True,
- ):
- debug = self._log.debug
- # only import here because of some problems importing globally
- import ipyleaflet # type: ignore
- import ipywidgets # type: ignore
-
- debug("slice dataset to variant of interest")
- if isinstance(variant, int):
- ds_variant = ds.isel(variants=variant)
- variant_label = ds["variant_label"].values[variant]
- elif isinstance(variant, str):
- ds_variant = ds.set_index(variants="variant_label").sel(variants=variant)
- variant_label = variant
- else:
- raise TypeError(
- f"Bad type for variant parameter; expected int or str, found {type(variant)}."
- )
-
- debug("convert to a dataframe for convenience")
- df_markers = ds_variant[
- [
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_lat_mean",
- "cohort_lon_mean",
- "cohort_size",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- ]
- ].to_dataframe()
-
- debug("select data matching taxon and period parameters")
- df_markers = df_markers.loc[
- (
- (df_markers["cohort_taxon"] == taxon)
- & (df_markers["cohort_period"] == period)
- )
- ]
-
- debug("clear existing layers in the map")
- if clear:
- for layer in m.layers[1:]:
- m.remove_layer(layer)
-
- debug("add markers")
- for x in df_markers.itertuples():
- marker = ipyleaflet.CircleMarker()
- marker.location = (x.cohort_lat_mean, x.cohort_lon_mean)
- marker.radius = 20
- marker.color = "black"
- marker.weight = 1
- marker.fill_color = "red"
- marker.fill_opacity = x.event_frequency
- popup_html = f"""
- {variant_label}
- Taxon: {x.cohort_taxon}
- Area: {x.cohort_area}
- Period: {x.cohort_period}
- Sample size: {x.cohort_size}
- Frequency: {x.event_frequency:.0%}
- (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%})
- """
- marker.popup = ipyleaflet.Popup(
- child=ipywidgets.HTML(popup_html),
- )
- m.add_layer(marker)
-
- @check_types
- @doc(
- summary="""
- Create an interactive map with markers showing variant frequencies for
- cohorts grouped by area (space), period (time) and taxon.
- """,
- parameters=dict(
- title="""
- If True, attempt to use metadata from input dataset as a plot
- title. Otherwise, use supplied value as a title.
- """,
- epilogue="Additional text to display below the map.",
- ),
- returns="""
- An interactive map with widgets for selecting which variant, taxon
- and time period to display.
- """,
- )
- def plot_frequencies_interactive_map(
- self,
- ds: frq_params.ds_frequencies_advanced,
- center: map_params.center = map_params.center_default,
- zoom: map_params.zoom = map_params.zoom_default,
- title: Union[bool, str] = True,
- epilogue: Union[bool, str] = True,
- ):
- debug = self._log.debug
-
- import ipyleaflet
- import ipywidgets
-
- debug("handle title")
- if title is True:
- title = ds.attrs.get("title", None)
-
- debug("create a map")
- freq_map = ipyleaflet.Map(center=center, zoom=zoom)
-
- debug("set up interactive controls")
- variants = ds["variant_label"].values
- taxa = np.unique(ds["cohort_taxon"].values)
- periods = np.unique(ds["cohort_period"].values)
- controls = ipywidgets.interactive(
- self.plot_frequencies_map_markers,
- m=ipywidgets.fixed(freq_map),
- ds=ipywidgets.fixed(ds),
- variant=ipywidgets.Dropdown(options=variants, description="Variant: "),
- taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "),
- period=ipywidgets.Dropdown(options=periods, description="Period: "),
- clear=ipywidgets.fixed(True),
- )
-
- debug("lay out widgets")
- components = []
- if title is not None:
- components.append(ipywidgets.HTML(value=f"{title}"))
- components.append(controls)
- components.append(freq_map)
- if epilogue is True:
- epilogue = """
- Variant frequencies are shown as coloured markers. Opacity of color
- denotes frequency. Click on a marker for more information.
- """
- if epilogue:
- components.append(ipywidgets.HTML(value=f"{epilogue}"))
-
- out = ipywidgets.VBox(components)
-
- return out
-
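A minimal usage sketch matching the notebook cells below; argument values are placeholders.

import malariagen_data

ag3 = malariagen_data.Ag3()
ds = ag3.aa_allele_frequencies_advanced(
    transcript="AGAP004707-RD",  # placeholder transcript ID
    area_by="admin1_iso",        # placeholder area column
    period_by="year",
    sample_sets="3.0",           # placeholder sample set release
)
ag3.plot_frequencies_interactive_map(ds)  # returns an ipywidgets.VBox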
@check_types
@doc(
summary="Plot diversity summary statistics for multiple cohorts.",
diff --git a/malariagen_data/util.py b/malariagen_data/util.py
index d756f8430..8e3c52b6b 100644
--- a/malariagen_data/util.py
+++ b/malariagen_data/util.py
@@ -10,26 +10,27 @@
from textwrap import dedent, fill
from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union
from urllib.parse import unquote_plus
+from numpy.testing import assert_allclose, assert_array_equal
try:
- from google import colab
+ from google import colab # type: ignore
except ImportError:
colab = None
-import allel
+import allel # type: ignore
import dask.array as da
-import ipinfo
-import numba
+import ipinfo # type: ignore
+import numba # type: ignore
import numpy as np
import pandas
import pandas as pd
-import plotly.express as px
+import plotly.express as px # type: ignore
import typeguard
import xarray as xr
-import zarr
-from fsspec.core import url_to_fs
-from fsspec.mapping import FSMap
-from numpydoc_decorator.impl import humanize_type
+import zarr # type: ignore
+from fsspec.core import url_to_fs # type: ignore
+from fsspec.mapping import FSMap # type: ignore
+from numpydoc_decorator.impl import humanize_type # type: ignore
from typing_extensions import TypeAlias, get_type_hints
DIM_VARIANT = "variants"
@@ -116,8 +117,7 @@ def unpack_gff3_attributes(df: pd.DataFrame, attributes: Tuple[str, ...]):
try:
# zarr >= 2.11.0
- # noinspection PyUnresolvedReferences
- from zarr.storage import KVStore
+ from zarr.storage import KVStore # type: ignore
class SafeStore(KVStore):
def __getitem__(self, key):
@@ -788,7 +788,7 @@ def jackknife_ci(stat_data, jack_stat, confidence_level):
https://github.com/astropy/astropy/blob/8aba9632597e6bb489488109222bf2feff5835a6/astropy/stats/jackknife.py#L55
"""
- from scipy.special import erfinv
+ from scipy.special import erfinv # type: ignore
n = len(jack_stat)
@@ -931,7 +931,7 @@ def check_types(f):
"""
@wraps(f)
- def wrapper(*args, **kwargs):
+ def check_types_wrapper(*args, **kwargs):
type_hints = get_type_hints(f)
call_args = getcallargs(f, *args, **kwargs)
for k, t in type_hints.items():
@@ -955,7 +955,7 @@ def wrapper(*args, **kwargs):
raise error from None
return f(*args, **kwargs)
- return wrapper
+ return check_types_wrapper
@numba.njit
@@ -1151,3 +1151,27 @@ def apply_allele_mapping(x, mapping, max_allele):
out[i, new_allele_index] = x[i, allele_index]
return out
+
+
+def pandas_apply(f, df, columns):
+ """Optimised alternative to pandas apply."""
+ df = df.reset_index(drop=True)
+ iterator = zip(*[df[c].values for c in columns])
+ ret = pd.Series((f(*vals) for vals in iterator))
+ return ret
+
+
+def compare_series_like(actual, expect):
+ """Compare pandas series-like objects for equality or floating point
+ similarity, handling missing values appropriately."""
+
+ # Handle object arrays, these don't get nans compared properly.
+ t = actual.dtype
+ if t == object:
+ expect = expect.fillna("NA")
+ actual = actual.fillna("NA")
+
+ if t.kind == "f":
+ assert_allclose(actual.values, expect.values)
+ else:
+ assert_array_equal(actual.values, expect.values)
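A small usage sketch of the pandas_apply helper added above; the label formatter and data are illustrative.

import pandas as pd
from malariagen_data.util import pandas_apply

def make_label(contig, position):
    # Illustrative label formatter.
    return f"{contig}:{position:,}"

df = pd.DataFrame({"contig": ["2L", "2L"], "position": [100_000, 250_000]})
labels = pandas_apply(make_label, df, columns=["contig", "position"])
assert labels.tolist() == ["2L:100,000", "2L:250,000"]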
diff --git a/malariagen_data/veff.py b/malariagen_data/veff.py
index 2dcb655cc..918553582 100644
--- a/malariagen_data/veff.py
+++ b/malariagen_data/veff.py
@@ -1,7 +1,7 @@
import collections
import operator
-from Bio.Seq import Seq
+from Bio.Seq import Seq # type: ignore
VariantEffect = collections.namedtuple(
"VariantEffect",
@@ -212,8 +212,9 @@ def _get_within_transcript_effect(ann, base_effect, cdss, utr5, utr3, introns):
effect = base_effect._replace(effect="THREE_PRIME_UTR", impact="LOW")
return effect
- # if none of the above
- effect = base_effect._replace(effect="TODO", impact="UNKNOWN")
+ # If none of the above, all we can say is that the variant hits
+ # a transcript.
+ effect = base_effect._replace(effect="TRANSCRIPT", impact="MODIFIER")
return effect
@@ -364,7 +365,9 @@ def _get_within_cds_effect(ann, base_effect, cds, cdss):
else:
# TODO in-frame complex variation (MNP + INDEL)
- effect = base_effect._replace(effect="TODO", impact="UNKNOWN")
+ effect = base_effect._replace(
+ effect="TODO in-frame complex variation (MNP + INDEL)", impact="UNKNOWN"
+ )
return effect
@@ -536,7 +539,7 @@ def _get_within_intron_effect(base_effect, intron):
effect = base_effect._replace(effect="INTRONIC", impact="MODIFIER")
else:
- # TODO INDELs and MNPs
- effect = base_effect._replace(effect="TODO")
+ # TODO intronic INDELs and MNPs
+ effect = base_effect._replace(effect="TODO intronic indels and MNPs")
return effect
diff --git a/notebooks/plot_frequencies_space_time.ipynb b/notebooks/plot_frequencies_space_time.ipynb
index 15512cb5d..a77246fe6 100644
--- a/notebooks/plot_frequencies_space_time.ipynb
+++ b/notebooks/plot_frequencies_space_time.ipynb
@@ -4,14 +4,7 @@
"cell_type": "code",
"execution_count": null,
"id": "f820bc66-2fb2-4ca2-9b54-824e50d61a0a",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:46.703988Z",
- "iopub.status.busy": "2022-12-23T17:48:46.703072Z",
- "iopub.status.idle": "2022-12-23T17:48:48.388767Z",
- "shell.execute_reply": "2022-12-23T17:48:48.388206Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"import malariagen_data\n",
@@ -52,14 +45,7 @@
"cell_type": "code",
"execution_count": null,
"id": "1c612c69-27ee-4f50-b467-786bd998de58",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:48.391134Z",
- "iopub.status.busy": "2022-12-23T17:48:48.390799Z",
- "iopub.status.idle": "2022-12-23T17:48:49.928704Z",
- "shell.execute_reply": "2022-12-23T17:48:49.928102Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.gene_cnv_frequencies_advanced(\n",
@@ -77,14 +63,7 @@
"cell_type": "code",
"execution_count": null,
"id": "7b38fb26-9562-4f52-9381-ced0687b5a40",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:49.931065Z",
- "iopub.status.busy": "2022-12-23T17:48:49.930877Z",
- "iopub.status.idle": "2022-12-23T17:48:50.235303Z",
- "shell.execute_reply": "2022-12-23T17:48:50.234797Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, height=500, width=1000)"
@@ -94,14 +73,7 @@
"cell_type": "code",
"execution_count": null,
"id": "fb104e66-7ba0-42ec-9ed0-625a43b69d5b",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:50.278700Z",
- "iopub.status.busy": "2022-12-23T17:48:50.278350Z",
- "iopub.status.idle": "2022-12-23T17:48:52.814714Z",
- "shell.execute_reply": "2022-12-23T17:48:52.814176Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.gene_cnv_frequencies_advanced(\n",
@@ -118,14 +90,7 @@
"cell_type": "code",
"execution_count": null,
"id": "ccf5279d-fc5a-4efe-972e-60c7b9515558",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:52.817166Z",
- "iopub.status.busy": "2022-12-23T17:48:52.816909Z",
- "iopub.status.idle": "2022-12-23T17:48:52.905183Z",
- "shell.execute_reply": "2022-12-23T17:48:52.904720Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_interactive_map(ds)"
@@ -144,14 +109,7 @@
"cell_type": "code",
"execution_count": null,
"id": "bf75e5b1-f2ca-41f4-a0fd-08a03c979fa4",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:52.908656Z",
- "iopub.status.busy": "2022-12-23T17:48:52.908467Z",
- "iopub.status.idle": "2022-12-23T17:48:54.913588Z",
- "shell.execute_reply": "2022-12-23T17:48:54.912997Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.aa_allele_frequencies_advanced(\n",
@@ -170,14 +128,7 @@
"cell_type": "code",
"execution_count": null,
"id": "2c6ffd89-7edc-4f90-bd23-b03e15ce9714",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:54.915626Z",
- "iopub.status.busy": "2022-12-23T17:48:54.915442Z",
- "iopub.status.idle": "2022-12-23T17:48:55.001687Z",
- "shell.execute_reply": "2022-12-23T17:48:55.001166Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, height=400, width=600)"
@@ -187,14 +138,7 @@
"cell_type": "code",
"execution_count": null,
"id": "c058a196-2f78-49a0-8a18-cb2420c3278a",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:55.003660Z",
- "iopub.status.busy": "2022-12-23T17:48:55.003479Z",
- "iopub.status.idle": "2022-12-23T17:48:59.787903Z",
- "shell.execute_reply": "2022-12-23T17:48:59.787238Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.aa_allele_frequencies_advanced(\n",
@@ -212,14 +156,7 @@
"cell_type": "code",
"execution_count": null,
"id": "e733c877-a1f1-4dc6-b8a3-a9652430ff39",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:59.790255Z",
- "iopub.status.busy": "2022-12-23T17:48:59.790075Z",
- "iopub.status.idle": "2022-12-23T17:48:59.818333Z",
- "shell.execute_reply": "2022-12-23T17:48:59.817926Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_interactive_map(ds)"
@@ -286,14 +223,7 @@
"cell_type": "code",
"execution_count": null,
"id": "c3738721-8d16-40dd-8a4d-4b588af36b15",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:48:59.821747Z",
- "iopub.status.busy": "2022-12-23T17:48:59.821571Z",
- "iopub.status.idle": "2022-12-23T17:49:01.166705Z",
- "shell.execute_reply": "2022-12-23T17:49:01.166143Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.snp_allele_frequencies_advanced(\n",
@@ -316,14 +246,7 @@
"cell_type": "code",
"execution_count": null,
"id": "5c743429",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:49:01.169029Z",
- "iopub.status.busy": "2022-12-23T17:49:01.168811Z",
- "iopub.status.idle": "2022-12-23T17:49:01.648657Z",
- "shell.execute_reply": "2022-12-23T17:49:01.648133Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, height=900, width=900)"
@@ -333,14 +256,7 @@
"cell_type": "code",
"execution_count": null,
"id": "9f6a7920",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:49:01.653225Z",
- "iopub.status.busy": "2022-12-23T17:49:01.653021Z",
- "iopub.status.idle": "2022-12-23T17:49:02.676879Z",
- "shell.execute_reply": "2022-12-23T17:49:02.676318Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.snp_allele_frequencies_advanced(\n",
@@ -359,14 +275,7 @@
"cell_type": "code",
"execution_count": null,
"id": "03cfc0e7",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:49:02.678825Z",
- "iopub.status.busy": "2022-12-23T17:49:02.678648Z",
- "iopub.status.idle": "2022-12-23T17:49:02.835617Z",
- "shell.execute_reply": "2022-12-23T17:49:02.835088Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_time_series(ds, height=400, width=800)"
@@ -376,14 +285,7 @@
"cell_type": "code",
"execution_count": null,
"id": "28922f73-f38d-424a-a9c5-1e839c86567b",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:49:02.837997Z",
- "iopub.status.busy": "2022-12-23T17:49:02.837815Z",
- "iopub.status.idle": "2022-12-23T17:49:06.334806Z",
- "shell.execute_reply": "2022-12-23T17:49:06.334285Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ds = ag3.snp_allele_frequencies_advanced(\n",
@@ -401,14 +303,7 @@
"cell_type": "code",
"execution_count": null,
"id": "2a8ea4a0-6bf7-41a3-b734-06535e7dbbc8",
- "metadata": {
- "execution": {
- "iopub.execute_input": "2022-12-23T17:49:06.337135Z",
- "iopub.status.busy": "2022-12-23T17:49:06.336916Z",
- "iopub.status.idle": "2022-12-23T17:49:06.365703Z",
- "shell.execute_reply": "2022-12-23T17:49:06.365218Z"
- }
- },
+ "metadata": {},
"outputs": [],
"source": [
"ag3.plot_frequencies_interactive_map(ds)"
diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py
index e08e374b4..4869cd4e2 100644
--- a/tests/anoph/conftest.py
+++ b/tests/anoph/conftest.py
@@ -73,8 +73,8 @@ def __init__(
n_exons_low=1,
n_exons_high=5,
intron_size_low=10,
- intron_size_high=1_000,
- exon_size_low=10,
+ intron_size_high=100,
+ exon_size_low=100,
exon_size_high=1_000,
source="random",
max_genes=1_000,
@@ -256,15 +256,21 @@ def simulate_exons(
transcript_size = transcript_end - transcript_start
exons = []
exon_end = transcript_start
- for exon_ix in range(randint(self.n_exons_low, self.n_exons_high)):
+ n_exons = randint(self.n_exons_low, self.n_exons_high)
+ for exon_ix in range(n_exons):
exon_id = f"exon-{contig}-{gene_ix}-{transcript_ix}-{exon_ix}"
- intron_size = randint(
- self.intron_size_low, min(transcript_size, self.intron_size_high)
- )
- exon_start = exon_end + intron_size
- if exon_start >= transcript_end:
- # Stop making exons, no more space left in the transcript.
- break
+ if exon_ix > 0:
+ # Insert an intron between this exon and the previous one.
+ intron_size = randint(
+ self.intron_size_low, min(transcript_size, self.intron_size_high)
+ )
+ exon_start = exon_end + intron_size
+ if exon_start >= transcript_end:
+ # Stop making exons, no more space left in the transcript.
+ break
+ else:
+ # First exon, assume exon starts where the transcript starts.
+ exon_start = transcript_start
exon_size = randint(self.exon_size_low, self.exon_size_high)
exon_end = min(exon_start + exon_size, transcript_end)
assert exon_end > exon_start
@@ -282,20 +288,20 @@ def simulate_exons(
yield exon
exons.append(exon)
- # Note that this is not perfect, because sometimes we end up
- # without any CDSs. Also in reality, an exon can contain
- # part of a UTR and part of a CDS, but that is harder to
- # simulate. So keep things simple for now.
+ # Note that this is not perfect, because in reality an exon can contain
+ # part of a UTR and part of a CDS, but that is harder to simulate. So
+ # keep things simple for now.
if strand == "-":
# Take exons in reverse order.
exons = exons[::-1]
for exon_ix, exon in enumerate(exons):
first_exon = exon_ix == 0
last_exon = exon_ix == len(exons) - 1
- if first_exon:
+ # Ensure at least one CDS.
+ if first_exon and len(exons) > 1:
feature_type = self.utr5_type
phase = "."
- elif last_exon:
+ elif last_exon and len(exons) > 2:
feature_type = self.utr3_type
phase = "."
else:
@@ -1005,11 +1011,14 @@ def contigs(self) -> Tuple[str, ...]:
def random_contig(self):
return choice(self.contigs)
- def random_region_str(self):
+ def random_region_str(self, region_size=None):
contig = self.random_contig()
contig_size = self.contig_sizes[contig]
region_start = randint(1, contig_size)
- region_end = randint(region_start, contig_size)
+ if region_size:
+ region_end = region_start + region_size
+ else:
+ region_end = randint(region_start, contig_size)
region = f"{contig}:{region_start:,}-{region_end:,}"
return region
diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py
index e340cde6a..2261f5b6d 100644
--- a/tests/anoph/test_pca.py
+++ b/tests/anoph/test_pca.py
@@ -2,7 +2,7 @@
import numpy as np
import pandas as pd
-import plotly.graph_objects as go
+import plotly.graph_objects as go # type: ignore
import pytest
from pytest_cases import parametrize_with_cases
diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py
index c3babbae8..1fd2e57fa 100644
--- a/tests/anoph/test_sample_metadata.py
+++ b/tests/anoph/test_sample_metadata.py
@@ -683,14 +683,21 @@ def test_count_samples(fixture, api):
)
-@pytest.mark.parametrize(
- "basemap", ["satellite", None, ipyleaflet.basemaps.OpenTopoMap]
-)
@parametrize_with_cases("fixture,api", cases=".")
-def test_plot_samples_interactive_map(fixture, api, basemap):
- m = api.plot_samples_interactive_map(basemap=basemap)
+def test_plot_samples_interactive_map(fixture, api):
+ # Test behaviour with bad basemap param.
+ with pytest.raises(ValueError):
+ api.plot_samples_interactive_map(basemap="foobar")
+
+ # Default params.
+ m = api.plot_samples_interactive_map()
assert isinstance(m, ipyleaflet.Map)
+ # Explicit params.
+ for basemap in ["satellite", None, ipyleaflet.basemaps.OpenTopoMap]:
+ m = api.plot_samples_interactive_map(basemap=basemap, width=500, height=300)
+ assert isinstance(m, ipyleaflet.Map)
+
@parametrize_with_cases("fixture,api", cases=".")
def test_wgs_data_catalog(fixture, api):
diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py
new file mode 100644
index 000000000..6d580b10f
--- /dev/null
+++ b/tests/anoph/test_snp_frq.py
@@ -0,0 +1,1469 @@
+import random
+
+import numpy as np
+import pandas as pd
+from pandas.testing import assert_frame_equal
+import pytest
+from pytest_cases import parametrize_with_cases
+import xarray as xr
+from numpy.testing import assert_allclose, assert_array_equal
+import plotly.graph_objects as go # type: ignore
+
+from malariagen_data import af1 as _af1
+from malariagen_data import ag3 as _ag3
+from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis
+from malariagen_data.util import compare_series_like
+
+
+@pytest.fixture
+def ag3_sim_api(ag3_sim_fixture):
+ return AnophelesSnpFrequencyAnalysis(
+ url=ag3_sim_fixture.url,
+ config_path=_ag3.CONFIG_PATH,
+ gcs_url=_ag3.GCS_URL,
+ major_version_number=_ag3.MAJOR_VERSION_NUMBER,
+ major_version_path=_ag3.MAJOR_VERSION_PATH,
+ pre=True,
+ aim_metadata_dtype={
+ "aim_species_fraction_arab": "float64",
+ "aim_species_fraction_colu": "float64",
+ "aim_species_fraction_colu_no2l": "float64",
+ "aim_species_gambcolu_arabiensis": object,
+ "aim_species_gambiae_coluzzii": object,
+ "aim_species": object,
+ },
+ gff_gene_type="gene",
+ gff_default_attributes=("ID", "Parent", "Name", "description"),
+ default_site_mask="gamb_colu_arab",
+ results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_ag3.TAXON_COLORS,
+ )
+
+
+@pytest.fixture
+def af1_sim_api(af1_sim_fixture):
+ return AnophelesSnpFrequencyAnalysis(
+ url=af1_sim_fixture.url,
+ config_path=_af1.CONFIG_PATH,
+ gcs_url=_af1.GCS_URL,
+ major_version_number=_af1.MAJOR_VERSION_NUMBER,
+ major_version_path=_af1.MAJOR_VERSION_PATH,
+ pre=False,
+ gff_gene_type="protein_coding_gene",
+ gff_default_attributes=("ID", "Parent", "Note", "description"),
+ default_site_mask="funestus",
+ results_cache=af1_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_af1.TAXON_COLORS,
+ )
+
+
+# N.B., here we use pytest_cases to parametrize tests. Each
+# function whose name begins with "case_" defines a set of
+# inputs to the test functions. See the documentation for
+# pytest_cases for more information, e.g.:
+#
+# https://smarie.github.io/python-pytest-cases/#basic-usage
+#
+# We use this approach here because we want to use fixtures
+# as test parameters, which is otherwise hard to do with
+# pytest alone.
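+#
+# Each "case_" function below returns a (fixture, api) pair, and any test
+# decorated with @parametrize_with_cases("fixture,api", cases=".") is run
+# once per case, i.e., once against the simulated Ag3 data and once against
+# the simulated Af1 data.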
+
+
+def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):
+ return ag3_sim_fixture, ag3_sim_api
+
+
+def case_af1_sim(af1_sim_fixture, af1_sim_api):
+ return af1_sim_fixture, af1_sim_api
+
+
+expected_alleles = list("ACGT")
+expected_effects = [
+ "FIVE_PRIME_UTR",
+ "THREE_PRIME_UTR",
+ "SYNONYMOUS_CODING",
+ "NON_SYNONYMOUS_CODING",
+ "START_LOST",
+ "STOP_LOST",
+ "STOP_GAINED",
+ "SPLICE_CORE",
+ "SPLICE_REGION",
+ "INTRONIC",
+ "TRANSCRIPT",
+]
+expected_impacts = [
+ "HIGH",
+ "MODERATE",
+ "LOW",
+ "MODIFIER",
+]
+
+
+def random_transcript(*, api):
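+ # Choose a random mRNA feature from the genome annotations. The returned
+ # pandas Series has the transcript ID as its name and includes the contig,
+ # start and end coordinates used by the checks below.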
+ df_gff = api.genome_features(attributes=["ID", "Parent"])
+ df_transcripts = df_gff.query("type == 'mRNA'")
+ transcript_ids = df_transcripts["ID"].dropna().to_list()
+ transcript_id = random.choice(transcript_ids)
+ transcript = df_transcripts.set_index("ID").loc[transcript_id]
+ return transcript
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_snp_effects(fixture, api: AnophelesSnpFrequencyAnalysis):
+ # Pick a random transcript.
+ transcript = random_transcript(api=api)
+
+ # Pick a random site mask.
+ site_mask = random.choice(api.site_mask_ids + (None,))
+
+ # Compute effects.
+ df = api.snp_effects(transcript=transcript.name, site_mask=site_mask)
+ assert isinstance(df, pd.DataFrame)
+
+ # Check columns.
+ expected_fields = (
+ [
+ "contig",
+ "position",
+ "ref_allele",
+ "alt_allele",
+ ]
+ + [f"pass_{m}" for m in api.site_mask_ids]
+ + [
+ "transcript",
+ "effect",
+ "impact",
+ "ref_codon",
+ "alt_codon",
+ "aa_pos",
+ "ref_aa",
+ "alt_aa",
+ "aa_change",
+ ]
+ )
+ assert df.columns.tolist() == expected_fields
+
+ # Check some values.
+ assert np.all(df["contig"] == transcript["contig"])
+ position = df["position"].values
+ assert np.all(position >= transcript["start"])
+ assert np.all(position <= transcript["end"])
+ assert np.all(position[1:] >= position[:-1])
+ expected_alleles = list("ACGT")
+ assert np.all(df["ref_allele"].isin(expected_alleles))
+ assert np.all(df["alt_allele"].isin(expected_alleles))
+ assert np.all(df["transcript"] == transcript.name)
+ assert np.all(df["effect"].isin(expected_effects))
+ assert np.all(df["impact"].isin(expected_impacts))
+ df_aa = df[~df["aa_change"].isna()]
+ expected_aa_change = (
+ df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"]
+ )
+ assert np.all(df_aa["aa_change"] == expected_aa_change)
+
+
+def check_frequency(x):
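+ # All non-missing frequency values should lie within [0, 1].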
+ loc_nan = np.isnan(x)
+ assert np.all(x[~loc_nan] >= 0)
+ assert np.all(x[~loc_nan] <= 1)
+
+
+def check_snp_allele_frequencies(
+ *,
+ api,
+ df,
+ cohort_labels,
+ transcript,
+):
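+ # Shared checks for the dataframe returned by snp_allele_frequencies():
+ # expected columns and index names, plus basic sanity checks on positions,
+ # alleles, effects and frequency values.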
+ assert isinstance(df, pd.DataFrame)
+
+ # Check columns.
+ universal_fields = [f"pass_{m}" for m in api.site_mask_ids] + [
+ "label",
+ ]
+ effects_fields = [
+ "transcript",
+ "effect",
+ "impact",
+ "ref_codon",
+ "alt_codon",
+ "aa_pos",
+ "ref_aa",
+ "alt_aa",
+ ]
+ frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"]
+ expected_fields = universal_fields + frq_fields + effects_fields
+ assert sorted(df.columns.tolist()) == sorted(expected_fields)
+ assert df.index.names == [
+ "contig",
+ "position",
+ "ref_allele",
+ "alt_allele",
+ "aa_change",
+ ]
+
+ # Check some values.
+ df = df.reset_index()
+ assert np.all(df["contig"] == transcript["contig"])
+ position = df["position"].values
+ assert np.all(position >= transcript["start"])
+ assert np.all(position <= transcript["end"])
+ assert np.all(position[1:] >= position[:-1])
+ assert np.all(df["ref_allele"].isin(expected_alleles))
+ assert np.all(df["alt_allele"].isin(expected_alleles))
+ assert np.all(df["transcript"] == transcript.name)
+ assert np.all(df["effect"].isin(expected_effects))
+ assert np.all(df["impact"].isin(expected_impacts))
+ df_aa = df[~df["aa_change"].isna()]
+ expected_aa_change = (
+ df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"]
+ )
+ assert np.all(df_aa["aa_change"] == expected_aa_change)
+ for f in frq_fields:
+ x = df[f]
+ check_frequency(x)
+
+
+def check_aa_allele_frequencies(
+ *,
+ df,
+ cohort_labels,
+ transcript,
+):
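+ # Shared checks for the dataframe returned by aa_allele_frequencies(),
+ # mirroring check_snp_allele_frequencies() above.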
+ assert isinstance(df, pd.DataFrame)
+
+ # Check columns.
+ universal_fields = [
+ "label",
+ ]
+ effects_fields = [
+ "transcript",
+ "effect",
+ "impact",
+ "aa_pos",
+ "ref_allele",
+ "alt_allele",
+ "ref_aa",
+ "alt_aa",
+ ]
+ frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"]
+ expected_fields = universal_fields + frq_fields + effects_fields
+ assert sorted(df.columns.tolist()) == sorted(expected_fields)
+ assert df.index.names == [
+ "aa_change",
+ "contig",
+ "position",
+ ]
+
+ # Check some values.
+ df = df.reset_index()
+ assert np.all(df["contig"] == transcript["contig"])
+ position = df["position"].values
+ assert np.all(position >= transcript["start"])
+ assert np.all(position <= transcript["end"])
+ assert np.all(position[1:] >= position[:-1])
+ assert np.all(df["ref_allele"].isin(expected_alleles))
+ # N.B., alt_allele may contain multiple alleles, e.g., "{A,T}", if
+ # multiple SNP alleles at the same position cause the same amino acid
+ # change.
+ assert np.all(df["transcript"] == transcript.name)
+ assert np.all(df["effect"].isin(expected_effects))
+ assert np.all(df["impact"].isin(expected_impacts))
+ df_aa = df[~df["aa_change"].isna()]
+ expected_aa_change = (
+ df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"]
+ )
+ assert np.all(df_aa["aa_change"] == expected_aa_change)
+ for f in frq_fields:
+ x = df[f]
+ check_frequency(x)
+
+
+@pytest.mark.parametrize(
+ "cohorts", ["admin1_year", "admin2_month", "country", "foobar"]
+)
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_str_cohorts(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ cohorts,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api)
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=True,
+ )
+
+ # Test behaviour with bad cohorts param.
+ if cohorts == "foobar":
+ with pytest.raises(ValueError):
+ api.snp_allele_frequencies(**params)
+ return
+
+ # Run the function under test.
+ df_snp = api.snp_allele_frequencies(**params)
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ if "cohort_" + cohorts in df_samples:
+ cohort_column = "cohort_" + cohorts
+ else:
+ cohort_column = cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Run the function under test.
+ df_aa = api.aa_allele_frequencies(**params)
+
+ # Standard checks.
+ check_aa_allele_frequencies(
+ df=df_aa,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+
+@pytest.mark.parametrize("min_cohort_size", [0, 10, 100])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_min_cohort_size(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ min_cohort_size,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ transcript = random_transcript(api=api)
+ cohorts = "admin1_year"
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=True,
+ )
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ if "cohort_" + cohorts in df_samples:
+ cohort_column = "cohort_" + cohorts
+ else:
+ cohort_column = cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ if len(cohort_labels) == 0:
+ # No cohorts, expect error.
+ with pytest.raises(ValueError):
+ api.snp_allele_frequencies(**params)
+ with pytest.raises(ValueError):
+ api.aa_allele_frequencies(**params)
+ return
+
+ # Run the function under test.
+ df_snp = api.snp_allele_frequencies(**params)
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Run the function under test.
+ df_aa = api.aa_allele_frequencies(**params)
+
+ # Standard checks.
+ check_aa_allele_frequencies(
+ df=df_aa,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_str_cohorts_and_sample_query(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ sample_sets = None
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = 0
+ transcript = random_transcript(api=api)
+ cohorts = random.choice(
+ ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
+ )
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ countries = df_samples["country"].unique()
+ country = random.choice(countries)
+ sample_query = f"country == '{country}'"
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query)
+ cohort_column = "cohort_" + cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ drop_invariant=True,
+ )
+
+ # Run the function under test.
+ df_snp = api.snp_allele_frequencies(**params)
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Run the function under test.
+ df_aa = api.aa_allele_frequencies(**params)
+
+ # Standard checks.
+ check_aa_allele_frequencies(
+ df=df_aa,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_dict_cohorts(
+ fixture, api: AnophelesSnpFrequencyAnalysis
+):
+ # Pick test parameters at random.
+ sample_sets = None # all sample sets
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api)
+
+ # Create cohorts by country.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ cohort_counts = df_samples["country"].value_counts()
+ cohorts = {cohort: f"country == '{cohort}'" for cohort in cohort_counts.index}
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=True,
+ )
+
+ # Run the function under test.
+ df_snp = api.snp_allele_frequencies(**params)
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Run the function under test.
+ df_aa = api.aa_allele_frequencies(**params)
+
+ # Standard checks.
+ check_aa_allele_frequencies(
+ df=df_aa,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_without_drop_invariant(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api)
+ cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ if "cohort_" + cohorts in df_samples:
+ cohort_column = "cohort_" + cohorts
+ else:
+ cohort_column = cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ )
+
+ # Run the function under test.
+ df_snp_a = api.snp_allele_frequencies(drop_invariant=True, **params)
+ df_snp_b = api.snp_allele_frequencies(drop_invariant=False, **params)
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp_a,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp_b,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Check specifics.
+ assert len(df_snp_b) > len(df_snp_a)
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_without_effects(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api)
+ cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ if "cohort_" + cohorts in df_samples:
+ cohort_column = "cohort_" + cohorts
+ else:
+ cohort_column = cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=True,
+ )
+
+ # Run the function under test.
+ df_snp_a = api.snp_allele_frequencies(effects=True, **params)
+ df_snp_b = api.snp_allele_frequencies(effects=False, **params)
+
+ # Standard checks.
+ check_snp_allele_frequencies(
+ api=api,
+ df=df_snp_a,
+ cohort_labels=cohort_labels,
+ transcript=transcript,
+ )
+
+ # Check specifics.
+ assert len(df_snp_b) == len(df_snp_a)
+
+ # Check columns and index names.
+ filter_fields = [f"pass_{m}" for m in api.site_mask_ids]
+ universal_fields = filter_fields + ["label"]
+ frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"]
+ expected_fields = universal_fields + frq_fields
+ assert sorted(df_snp_b.columns.tolist()) == sorted(expected_fields)
+ assert df_snp_b.index.names == [
+ "contig",
+ "position",
+ "ref_allele",
+ "alt_allele",
+ ]
+
+ # Compare values with and without effects.
+ comparable_fields = (
+ [
+ "contig",
+ "position",
+ "ref_allele",
+ "alt_allele",
+ ]
+ + filter_fields
+ + frq_fields
+ )
+ # N.B., values of the "label" field are different with and without
+ # effects, so don't compare them.
+ assert_frame_equal(
+ df_snp_b.reset_index()[comparable_fields],
+ df_snp_a.reset_index()[comparable_fields],
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_bad_transcript(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+
+ # Set up call params.
+ params = dict(
+ transcript="foobar",
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=True,
+ )
+
+ # Run the function under test.
+ with pytest.raises(ValueError):
+ api.snp_allele_frequencies(**params)
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_region(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+ # This should work as long as effects=False, i.e., frequencies can be
+ # obtained for any genome region.
+ transcript = fixture.random_region_str(region_size=500)
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ drop_invariant=False,
+ effects=False,
+ )
+
+ # Run the function under test.
+ df_snp = api.snp_allele_frequencies(**params)
+
+ # Basic checks.
+ assert isinstance(df_snp, pd.DataFrame)
+ assert len(df_snp) > 0
+
+ # Figure out expected cohort labels.
+ df_samples = api.sample_metadata(sample_sets=sample_sets)
+ if "cohort_" + cohorts in df_samples:
+ cohort_column = "cohort_" + cohorts
+ else:
+ cohort_column = cohorts
+ cohort_counts = df_samples[cohort_column].value_counts()
+ cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list()
+
+ # Check columns and index names.
+ filter_fields = [f"pass_{m}" for m in api.site_mask_ids]
+ universal_fields = filter_fields + ["label"]
+ frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"]
+ expected_fields = universal_fields + frq_fields
+ assert sorted(df_snp.columns.tolist()) == sorted(expected_fields)
+ assert df_snp.index.names == [
+ "contig",
+ "position",
+ "ref_allele",
+ "alt_allele",
+ ]
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_with_dup_samples(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_set = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api)
+ cohorts = random.choice(["admin1_year", "admin2_month", "country"])
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript.name,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ )
+
+ # Run the function under test.
+ df_snp_a = api.snp_allele_frequencies(sample_sets=[sample_set], **params)
+ df_snp_b = api.snp_allele_frequencies(
+ sample_sets=[sample_set, sample_set], **params
+ )
+
+ # Expect sample sets to be deduplicated automatically.
+ assert_frame_equal(df_snp_b, df_snp_a)
+
+ # Run the function under test.
+ df_aa_a = api.aa_allele_frequencies(sample_sets=[sample_set], **params)
+ df_aa_b = api.aa_allele_frequencies(sample_sets=[sample_set, sample_set], **params)
+
+ # Expect sample sets to be deduplicated automatically.
+ assert_frame_equal(df_aa_b, df_aa_a)
+
+
+def check_snp_allele_frequencies_advanced(
+ *,
+ api: AnophelesSnpFrequencyAnalysis,
+ transcript=None,
+ area_by="admin1_iso",
+ period_by="year",
+ sample_sets=None,
+ sample_query=None,
+ min_cohort_size=None,
+ nobs_mode="called",
+ variant_query=None,
+ site_mask=None,
+):
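+ # Shared checks for snp_allele_frequencies_advanced(): any parameters left
+ # as None are chosen at random, then the returned dataset is checked for
+ # the expected structure and, where comparable, for consistency with the
+ # simpler snp_allele_frequencies() method.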
+ # Pick test parameters at random.
+ if transcript is None:
+ transcript = random_transcript(api=api).name
+ if area_by is None:
+ area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+ if period_by is None:
+ period_by = random.choice(["year", "quarter", "month"])
+ if sample_sets is None:
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ if min_cohort_size is None:
+ min_cohort_size = random.randint(0, 2)
+ if site_mask is None:
+ site_mask = random.choice(api.site_mask_ids + (None,))
+
+ # Run function under test.
+ ds = api.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ nobs_mode=nobs_mode,
+ variant_query=variant_query,
+ site_mask=site_mask,
+ )
+
+ # Check the result.
+ assert isinstance(ds, xr.Dataset)
+ assert set(ds.dims) == {"cohorts", "variants"}
+
+ # Check variant variables.
+ expected_variant_vars = [
+ "variant_label",
+ "variant_contig",
+ "variant_position",
+ "variant_ref_allele",
+ "variant_alt_allele",
+ "variant_max_af",
+ "variant_transcript",
+ "variant_effect",
+ "variant_impact",
+ "variant_ref_codon",
+ "variant_alt_codon",
+ "variant_ref_aa",
+ "variant_alt_aa",
+ "variant_aa_pos",
+ "variant_aa_change",
+ ]
+ expected_variant_vars += [f"variant_pass_{m}" for m in api.site_mask_ids]
+ for v in expected_variant_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("variants",)
+
+ # Check cohort variables.
+ expected_cohort_vars = [
+ "cohort_label",
+ "cohort_size",
+ "cohort_taxon",
+ "cohort_area",
+ "cohort_period",
+ "cohort_period_start",
+ "cohort_period_end",
+ "cohort_lat_mean",
+ "cohort_lat_min",
+ "cohort_lat_max",
+ "cohort_lon_mean",
+ "cohort_lon_min",
+ "cohort_lon_max",
+ ]
+ for v in expected_cohort_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("cohorts",)
+
+ # Check event variables.
+ expected_event_vars = [
+ "event_count",
+ "event_nobs",
+ "event_frequency",
+ "event_frequency_ci_low",
+ "event_frequency_ci_upp",
+ ]
+ for v in expected_event_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("variants", "cohorts")
+
+ # Sanity check for frequency values.
+ x = ds["event_frequency"].values
+ check_frequency(x)
+
+ # Sanity check area values.
+ df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query)
+ expected_area_values = np.unique(df_samples[area_by].dropna().values)
+ area_values = ds["cohort_area"].values
+ # N.B., some areas may not end up in the final dataset if the cohort
+ # size is too small, so do a set membership test.
+ for a in area_values:
+ assert a in expected_area_values
+
+ # Sanity checks for period values.
+ period_values = ds["cohort_period"].values
+ if period_by == "year":
+ expected_freqstr = "A-DEC"
+ elif period_by == "month":
+ expected_freqstr = "M"
+ elif period_by == "quarter":
+ expected_freqstr = "Q-DEC"
+ else:
+ assert False, "not implemented"
+ for p in period_values:
+ assert isinstance(p, pd.Period)
+ assert p.freqstr == expected_freqstr
+
+ # Sanity check cohort sizes.
+ cohort_size_values = ds["cohort_size"].values
+ for s in cohort_size_values:
+ assert s >= min_cohort_size
+
+ if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
+ # Here we test the behaviour of the function when grouping by admin level
+ # 1 and year. We can do some more in-depth testing in this case because
+ # we can compare results directly against the simpler snp_allele_frequencies()
+ # function with the admin1_year cohorts.
+
+ # Check consistency with the basic snp allele frequencies method.
+ df_af = api.snp_allele_frequencies(
+ transcript=transcript,
+ cohorts="admin1_year",
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ include_counts=True,
+ )
+ # Make sure all variables are available to check.
+ df_af = df_af.reset_index()
+ if variant_query is not None:
+ df_af = df_af.query(variant_query)
+
+ # Check cohorts are consistent.
+ expect_cohort_labels = sorted(
+ [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
+ )
+ cohort_labels = sorted(ds["cohort_label"].values)
+ assert cohort_labels == expect_cohort_labels
+
+ # Check variants are consistent.
+ assert ds.sizes["variants"] == len(df_af)
+ for v in expected_variant_vars:
+ c = v.split("variant_")[1]
+ actual = ds[v]
+ expect = df_af[c]
+ compare_series_like(actual, expect)
+
+ # Check frequencies are consistent.
+ for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
+ actual_nobs = ds["event_nobs"].values[:, cohort_index]
+ expect_nobs = df_af[f"nobs_{cohort_label}"].values
+ assert_array_equal(actual_nobs, expect_nobs)
+ actual_count = ds["event_count"].values[:, cohort_index]
+ expect_count = df_af[f"count_{cohort_label}"].values
+ assert_array_equal(actual_count, expect_count)
+ actual_frq = ds["event_frequency"].values[:, cohort_index]
+ expect_frq = df_af[f"frq_{cohort_label}"].values
+ assert_allclose(actual_frq, expect_frq)
+
+
+def check_aa_allele_frequencies_advanced(
+ *,
+ api: AnophelesSnpFrequencyAnalysis,
+ transcript=None,
+ area_by="admin1_iso",
+ period_by="year",
+ sample_sets=None,
+ sample_query=None,
+ min_cohort_size=None,
+ nobs_mode="called",
+ variant_query=None,
+):
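+ # Shared checks for aa_allele_frequencies_advanced(), mirroring
+ # check_snp_allele_frequencies_advanced() above.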
+ # Pick test parameters at random.
+ if transcript is None:
+ transcript = random_transcript(api=api).name
+ if area_by is None:
+ area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+ if period_by is None:
+ period_by = random.choice(["year", "quarter", "month"])
+ if sample_sets is None:
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ if min_cohort_size is None:
+ min_cohort_size = random.randint(0, 2)
+
+ # Run function under test.
+ ds = api.aa_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ nobs_mode=nobs_mode,
+ variant_query=variant_query,
+ )
+
+ # Check the result.
+ assert isinstance(ds, xr.Dataset)
+ assert set(ds.dims) == {"cohorts", "variants"}
+
+ expected_variant_vars = (
+ "variant_label",
+ "variant_contig",
+ "variant_position",
+ "variant_max_af",
+ "variant_transcript",
+ "variant_effect",
+ "variant_impact",
+ "variant_ref_aa",
+ "variant_alt_aa",
+ "variant_aa_pos",
+ "variant_aa_change",
+ )
+ for v in expected_variant_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("variants",)
+
+ expected_cohort_vars = (
+ "cohort_label",
+ "cohort_size",
+ "cohort_taxon",
+ "cohort_area",
+ "cohort_period",
+ "cohort_period_start",
+ "cohort_period_end",
+ "cohort_lat_mean",
+ "cohort_lat_min",
+ "cohort_lat_max",
+ "cohort_lon_mean",
+ "cohort_lon_min",
+ "cohort_lon_max",
+ )
+ for v in expected_cohort_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("cohorts",)
+
+ expected_event_vars = (
+ "event_count",
+ "event_nobs",
+ "event_frequency",
+ "event_frequency_ci_low",
+ "event_frequency_ci_upp",
+ )
+ for v in expected_event_vars:
+ a = ds[v]
+ assert isinstance(a, xr.DataArray)
+ assert a.dims == ("variants", "cohorts")
+
+ # Sanity check for frequency values.
+ x = ds["event_frequency"].values
+ check_frequency(x)
+
+ # Sanity checks for area values.
+ df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query)
+ expected_area_values = np.unique(df_samples[area_by].dropna().values)
+ area_values = ds["cohort_area"].values
+ # N.B., some areas may not end up in the final dataset if the cohort
+ # size is too small, so do a set membership test.
+ for a in area_values:
+ assert a in expected_area_values
+
+ # Sanity checks for period values.
+ period_values = ds["cohort_period"].values
+ if period_by == "year":
+ expected_freqstr = "A-DEC"
+ elif period_by == "month":
+ expected_freqstr = "M"
+ elif period_by == "quarter":
+ expected_freqstr = "Q-DEC"
+ else:
+ assert False, "not implemented"
+ for p in period_values:
+ assert isinstance(p, pd.Period)
+ assert p.freqstr == expected_freqstr
+
+ # Sanity check cohort size.
+ cohort_size_values = ds["cohort_size"].values
+ for s in cohort_size_values:
+ assert s >= min_cohort_size
+
+ if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
+ # Here we test the behaviour of the function when grouping by admin level
+ # 1 and year. We can do some more in-depth testing in this case because
+ # we can compare results directly against the simpler aa_allele_frequencies()
+ # function with the admin1_year cohorts.
+
+ # Check consistency with the basic aa allele frequencies method.
+ df_af = api.aa_allele_frequencies(
+ transcript=transcript,
+ cohorts="admin1_year",
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=min_cohort_size,
+ include_counts=True,
+ )
+ # Make sure all variables are available to check.
+ df_af = df_af.reset_index()
+ if variant_query is not None:
+ df_af = df_af.query(variant_query)
+
+ # Check cohorts are consistent.
+ expect_cohort_labels = sorted(
+ [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
+ )
+ cohort_labels = sorted(ds["cohort_label"].values)
+ assert cohort_labels == expect_cohort_labels
+
+ # Check variants are consistent.
+ assert ds.sizes["variants"] == len(df_af)
+ for v in expected_variant_vars:
+ c = v.split("variant_")[1]
+ actual = ds[v]
+ expect = df_af[c]
+ compare_series_like(actual, expect)
+
+ # Check frequencies are consistent.
+ for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
+ actual_nobs = ds["event_nobs"].values[:, cohort_index]
+ expect_nobs = df_af[f"nobs_{cohort_label}"].values
+ assert_array_equal(actual_nobs, expect_nobs)
+ actual_count = ds["event_count"].values[:, cohort_index]
+ expect_count = df_af[f"count_{cohort_label}"].values
+ assert_array_equal(actual_count, expect_count)
+ actual_frq = ds["event_frequency"].values[:, cohort_index]
+ expect_frq = df_af[f"frq_{cohort_label}"].values
+ assert_allclose(actual_frq, expect_frq)
+
+
+# Here we don't explore the full matrix, but vary one parameter at a time, otherwise
+# the test suite would take too long to run.
+
+
+@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_area_by(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ area_by,
+):
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ area_by=area_by,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ area_by=area_by,
+ )
+
+
+@pytest.mark.parametrize("period_by", ["year", "quarter", "month"])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_period_by(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ period_by,
+):
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ period_by=period_by,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ period_by=period_by,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_sample_query(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ df_samples = api.sample_metadata(sample_sets=all_sample_sets)
+ countries = df_samples["country"].unique()
+ country = random.choice(countries)
+ sample_query = f"country == '{country}'"
+
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ sample_sets=all_sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=0,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ sample_sets=all_sample_sets,
+ sample_query=sample_query,
+ min_cohort_size=0,
+ )
+
+
+@pytest.mark.parametrize("min_cohort_size", [0, 10, 100])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_min_cohort_size(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ min_cohort_size,
+):
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ area_by = "admin1_iso"
+ period_by = "year"
+ transcript = random_transcript(api=api).name
+
+ if min_cohort_size <= 10:
+ # Expect this to find at least one cohort, so go ahead with full
+ # checks.
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ min_cohort_size=min_cohort_size,
+ area_by=area_by,
+ period_by=period_by,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ min_cohort_size=min_cohort_size,
+ area_by=area_by,
+ period_by=period_by,
+ )
+ else:
+ # Expect this to find no cohorts.
+ with pytest.raises(ValueError):
+ api.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ min_cohort_size=min_cohort_size,
+ area_by=area_by,
+ period_by=period_by,
+ )
+ with pytest.raises(ValueError):
+ api.aa_allele_frequencies_advanced(
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ min_cohort_size=min_cohort_size,
+ area_by=area_by,
+ period_by=period_by,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_variant_query(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ area_by = "admin1_iso"
+ period_by = "year"
+ transcript = random_transcript(api=api).name
+
+ # Test a query that should succeed.
+ variant_query = "effect == 'NON_SYNONYMOUS_CODING'"
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ area_by=area_by,
+ period_by=period_by,
+ variant_query=variant_query,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ area_by=area_by,
+ period_by=period_by,
+ variant_query=variant_query,
+ )
+
+ # Test a query that should fail.
+ variant_query = "effect == 'foobar'"
+ with pytest.raises(ValueError):
+ api.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ area_by=area_by,
+ period_by=period_by,
+ variant_query=variant_query,
+ )
+ with pytest.raises(ValueError):
+ api.aa_allele_frequencies_advanced(
+ transcript=transcript,
+ sample_sets=all_sample_sets,
+ area_by=area_by,
+ period_by=period_by,
+ variant_query=variant_query,
+ )
+
+
+@pytest.mark.parametrize("nobs_mode", ["called", "fixed"])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_nobs_mode(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+ nobs_mode,
+):
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ nobs_mode=nobs_mode,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ nobs_mode=nobs_mode,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_allele_frequencies_advanced_with_dup_samples(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_set = random.choice(all_sample_sets)
+ sample_sets = [sample_set, sample_set]
+
+ check_snp_allele_frequencies_advanced(
+ api=api,
+ sample_sets=sample_sets,
+ )
+ check_aa_allele_frequencies_advanced(
+ api=api,
+ sample_sets=sample_sets,
+ )
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_plot_frequencies_heatmap(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api).name
+ cohorts = random.choice(
+ ["admin1_year", "admin1_month", "admin2_year", "admin2_month"]
+ )
+
+ # Set up call params.
+ params = dict(
+ transcript=transcript,
+ cohorts=cohorts,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ )
+
+ # Test SNP allele frequencies.
+ df_snp = api.snp_allele_frequencies(**params)
+ fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None)
+ assert isinstance(fig, go.Figure)
+
+ # Test amino acid change allele frequencies.
+ df_aa = api.aa_allele_frequencies(**params)
+ fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None)
+ assert isinstance(fig, go.Figure)
+
+ # Test max_len behaviour.
+ with pytest.raises(ValueError):
+ api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1)
+
+ # Test index parameter - if None, should use dataframe index.
+ fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None)
+ # Not unique.
+ with pytest.raises(ValueError):
+ api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None)
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_plot_frequencies_time_series(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api).name
+ area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+ period_by = random.choice(["year", "quarter", "month"])
+
+ # Compute SNP frequencies.
+ ds = api.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ )
+
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_time_series(ds, show=False)
+
+ # Test.
+ assert isinstance(fig, go.Figure)
+
+ # Compute amino acid change frequencies.
+ ds = api.aa_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ min_cohort_size=min_cohort_size,
+ )
+
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_time_series(ds, show=False)
+
+ # Test.
+ assert isinstance(fig, go.Figure)
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_plot_frequencies_interactive_map(
+ fixture,
+ api: AnophelesSnpFrequencyAnalysis,
+):
+ import ipywidgets # type: ignore
+
+ # Pick test parameters at random.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ sample_sets = random.choice(all_sample_sets)
+ site_mask = random.choice(api.site_mask_ids + (None,))
+ min_cohort_size = random.randint(0, 2)
+ transcript = random_transcript(api=api).name
+ area_by = random.choice(["country", "admin1_iso", "admin2_name"])
+ period_by = random.choice(["year", "quarter", "month"])
+
+ # Compute SNP frequencies.
+ ds = api.snp_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ min_cohort_size=min_cohort_size,
+ site_mask=site_mask,
+ )
+
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_interactive_map(ds)
+
+ # Test.
+ assert isinstance(fig, ipywidgets.Widget)
+
+ # Compute amino acid change frequencies.
+ ds = api.aa_allele_frequencies_advanced(
+ transcript=transcript,
+ area_by=area_by,
+ period_by=period_by,
+ sample_sets=sample_sets,
+ min_cohort_size=min_cohort_size,
+ )
+
+ # Trim things down a bit for speed.
+ ds = ds.isel(variants=slice(0, 100))
+
+ # Plot.
+ fig = api.plot_frequencies_interactive_map(ds)
+
+ # Test.
+ assert isinstance(fig, ipywidgets.Widget)
diff --git a/tests/test_af1.py b/tests/test_af1.py
index b849225cd..2ef10dcc8 100644
--- a/tests/test_af1.py
+++ b/tests/test_af1.py
@@ -1,8 +1,6 @@
import numpy as np
-import pandas as pd
import pytest
-import xarray as xr
-from numpy.testing import assert_allclose, assert_array_equal
+from numpy.testing import assert_allclose
from malariagen_data import Af1, Region
from malariagen_data.util import locate_region, resolve_region
@@ -29,142 +27,6 @@ def test_repr():
assert isinstance(r, str)
-# TODO: test_snp_effects() for Af1.0
-# # reverse strand gene
-# gste2 = "LOC125761549"
-#
-# # test forward strand gene gste6
-# gste6 = "LOC125767311"
-#
-# # check 5' utr intron and the different intron effects
-# utr_intron5 = "utr_LOC125767311_t1_1"
-#
-# # check 3' utr intron
-# utr_intron3 = "utr_LOC125767311_t1_3"
-
-
-def test_snp_allele_frequencies__dict_cohorts():
- af1 = setup_af1(cohorts_analysis="20221129")
- cohorts = {
- "ke": "country == 'Kenya'",
- "gh_2017": "country == 'Ghana' and year == 2017",
- }
- universal_fields = [
- "pass_funestus",
- "label",
- ]
-
- # test drop invariants
- df = af1.snp_allele_frequencies(
- transcript="LOC125761549_t5",
- cohorts=cohorts,
- site_mask="funestus",
- sample_sets="1.0",
- drop_invariant=True,
- effects=False,
- )
-
- assert isinstance(df, pd.DataFrame)
- frq_columns = ["frq_" + s for s in list(cohorts.keys())]
- expected_fields = universal_fields + frq_columns + ["max_af"]
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.shape == (3185, len(expected_fields))
- assert df.iloc[2].frq_ke == 0
- assert df.iloc[2].frq_gh_2017 == pytest.approx(0.013889, abs=1e-6)
- assert df.iloc[2].max_af == pytest.approx(0.013889, abs=1e-6)
- # check invariant have been dropped
- assert df.max_af.min() > 0
-
- # test keep invariants
- df = af1.snp_allele_frequencies(
- transcript="LOC125761549_t7",
- cohorts=cohorts,
- site_mask="funestus",
- sample_sets="1.0",
- drop_invariant=False,
- effects=False,
- )
- assert isinstance(df, pd.DataFrame)
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.shape == (29418, len(expected_fields))
- # check invariant positions are still present
- assert np.any(df.max_af == 0)
-
-
-def test_snp_allele_frequencies__str_cohorts__effects():
- af1 = setup_af1(cohorts_analysis="20221129")
- cohorts = "admin1_month"
- min_cohort_size = 10
- universal_fields = [
- "pass_funestus",
- "label",
- ]
- effects_fields = [
- "transcript",
- "effect",
- "impact",
- "ref_codon",
- "alt_codon",
- "aa_pos",
- "ref_aa",
- "alt_aa",
- ]
- df = af1.snp_allele_frequencies(
- transcript="LOC125767311_t2",
- cohorts=cohorts,
- min_cohort_size=min_cohort_size,
- site_mask="funestus",
- sample_sets="1.0",
- drop_invariant=True,
- effects=True,
- )
- df_coh = af1.cohorts_metadata(sample_sets="1.0")
- coh_nm = "cohort_" + cohorts
- coh_counts = df_coh[coh_nm].dropna().value_counts()
- cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list()
- frq_cohort_labels = ["frq_" + s for s in cohort_labels]
- expected_fields = universal_fields + frq_cohort_labels + ["max_af"] + effects_fields
-
- assert isinstance(df, pd.DataFrame)
- assert len(df) == 4221
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.index.names == [
- "contig",
- "position",
- "ref_allele",
- "alt_allele",
- "aa_change",
- ]
-
-
-def test_snp_allele_frequencies__query():
- af1 = setup_af1(cohorts_analysis="20221129")
- cohorts = "admin1_year"
- min_cohort_size = 10
- expected_columns = [
- "pass_funestus",
- "frq_GH-AH_fune_2014",
- "frq_GH-NP_fune_2017",
- "max_af",
- "label",
- ]
-
- df = af1.snp_allele_frequencies(
- transcript="LOC125767311_t2",
- cohorts=cohorts,
- sample_query="country == 'Ghana'",
- min_cohort_size=min_cohort_size,
- site_mask="funestus",
- sample_sets="1.0",
- drop_invariant=True,
- effects=False,
- )
-
- assert isinstance(df, pd.DataFrame)
- assert sorted(df.columns) == sorted(expected_columns)
- assert len(df) == 1309
-
-
@pytest.mark.parametrize(
"region_raw",
[
@@ -216,470 +78,6 @@ def test_locate_region(region_raw):
assert region == Region("2RL", 24630355, 24633221)
-def test_aa_allele_frequencies():
- af1 = setup_af1(cohorts_analysis="20221129")
-
- expected_fields = [
- "transcript",
- "aa_pos",
- "ref_allele",
- "alt_allele",
- "ref_aa",
- "alt_aa",
- "effect",
- "impact",
- "frq_CD-HU_fune_2017",
- "frq_MZ-P_fune_2015",
- "max_af",
- "label",
- ]
-
- df = af1.aa_allele_frequencies(
- transcript="LOC125767311_t2",
- cohorts="admin1_year",
- min_cohort_size=10,
- site_mask="funestus",
- sample_sets=("1240-VO-CD-KOEKEMOER-VMF00099", "1240-VO-MZ-KOEKEMOER-VMF00101"),
- drop_invariant=True,
- )
-
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert isinstance(df, pd.DataFrame)
- assert df.index.names == ["aa_change", "contig", "position"]
- assert df.shape == (53, len(expected_fields))
- assert df.loc["V947L"].max_af[0] == pytest.approx(0.025, abs=1e-6)
-
-
-# noinspection PyDefaultArgument
-def _check_snp_allele_frequencies_advanced(
- transcript="LOC125767311_t2",
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[
- "1229-VO-GH-DADZIE-VMF00095",
- "1240-VO-CD-KOEKEMOER-VMF00099",
- "1240-VO-MZ-KOEKEMOER-VMF00101",
- ],
- sample_query=None,
- min_cohort_size=10,
- nobs_mode="called",
- variant_query="max_af > 0.02",
-):
- af1 = setup_af1(cohorts_analysis="20221129")
-
- ds = af1.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- nobs_mode=nobs_mode,
- variant_query=variant_query,
- )
-
- assert isinstance(ds, xr.Dataset)
-
- # noinspection PyTypeChecker
- assert sorted(ds.dims) == ["cohorts", "variants"]
-
- expected_variant_vars = (
- "variant_label",
- "variant_contig",
- "variant_position",
- "variant_ref_allele",
- "variant_alt_allele",
- "variant_max_af",
- "variant_pass_funestus",
- "variant_transcript",
- "variant_effect",
- "variant_impact",
- "variant_ref_codon",
- "variant_alt_codon",
- "variant_ref_aa",
- "variant_alt_aa",
- "variant_aa_pos",
- "variant_aa_change",
- )
- for v in expected_variant_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants",)
-
- expected_cohort_vars = (
- "cohort_label",
- "cohort_size",
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_period_start",
- "cohort_period_end",
- "cohort_lat_mean",
- "cohort_lat_min",
- "cohort_lat_max",
- "cohort_lon_mean",
- "cohort_lon_min",
- "cohort_lon_max",
- )
- for v in expected_cohort_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("cohorts",)
-
- expected_event_vars = (
- "event_count",
- "event_nobs",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- )
- for v in expected_event_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants", "cohorts")
-
- # sanity checks for area values
- df_samples = af1.sample_metadata(sample_sets=sample_sets)
- if sample_query is not None:
- df_samples = df_samples.query(sample_query)
- expected_area = np.unique(df_samples[area_by].dropna().values)
- area = ds["cohort_area"].values
- # N.B., some areas may not end up in final dataset if cohort
- # size is too small, so do a set membership test
- for a in area:
- assert a in expected_area
-
- # sanity checks for period values
- period = ds["cohort_period"].values
- if period_by == "year":
- expected_freqstr = "A-DEC"
- elif period_by == "month":
- expected_freqstr = "M"
- elif period_by == "quarter":
- expected_freqstr = "Q-DEC"
- else:
- assert False, "not implemented"
- for p in period:
- assert isinstance(p, pd.Period)
- assert p.freqstr == expected_freqstr
-
- # sanity check cohort size
- size = ds["cohort_size"].values
- for s in size:
- assert s >= min_cohort_size
-
- if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
- # Here we test the behaviour of the function when grouping by admin level
- # 1 and year. We can do some more in-depth testing in this case because
- # we can compare results directly against the simpler snp_allele_frequencies()
- # function with the admin1_year cohorts.
-
- # check consistency with the basic snp allele frequencies method
- df_af = af1.snp_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- )
- df_af = df_af.reset_index() # make sure all variables available to check
- if variant_query is not None:
- df_af = df_af.query(variant_query)
-
- # check cohorts are consistent
- expect_cohort_labels = sorted(
- [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
- )
- cohort_labels = sorted(ds["cohort_label"].values)
- assert cohort_labels == expect_cohort_labels
-
- # check variants are consistent
- assert ds.sizes["variants"] == len(df_af)
- for v in expected_variant_vars:
- c = v.split("variant_")[1]
- actual = ds[v]
- expect = df_af[c]
- _compare_series_like(actual, expect)
-
- # check frequencies are consistent
- for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
- actual_frq = ds["event_frequency"].values[:, cohort_index]
- expect_frq = df_af[f"frq_{cohort_label}"].values
- assert_allclose(actual_frq, expect_frq)
-
-
-# noinspection PyDefaultArgument
-def _check_aa_allele_frequencies_advanced(
- transcript="LOC125767311_t2",
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[
- "1229-VO-GH-DADZIE-VMF00095",
- "1240-VO-CD-KOEKEMOER-VMF00099",
- "1240-VO-MZ-KOEKEMOER-VMF00101",
- ],
- sample_query=None,
- min_cohort_size=10,
- nobs_mode="called",
- variant_query="max_af > 0.02",
-):
- af1 = setup_af1(cohorts_analysis="20221129")
-
- ds = af1.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- nobs_mode=nobs_mode,
- variant_query=variant_query,
- )
-
- assert isinstance(ds, xr.Dataset)
-
- # noinspection PyTypeChecker
- assert sorted(ds.dims) == ["cohorts", "variants"]
-
- expected_variant_vars = (
- "variant_label",
- "variant_contig",
- "variant_position",
- "variant_max_af",
- "variant_transcript",
- "variant_effect",
- "variant_impact",
- "variant_ref_aa",
- "variant_alt_aa",
- "variant_aa_pos",
- "variant_aa_change",
- )
- for v in expected_variant_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants",)
-
- expected_cohort_vars = (
- "cohort_label",
- "cohort_size",
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_period_start",
- "cohort_period_end",
- "cohort_lat_mean",
- "cohort_lat_min",
- "cohort_lat_max",
- "cohort_lon_mean",
- "cohort_lon_min",
- "cohort_lon_max",
- )
- for v in expected_cohort_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("cohorts",)
-
- expected_event_vars = (
- "event_count",
- "event_nobs",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- )
- for v in expected_event_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants", "cohorts")
-
- # sanity checks for area values
- df_samples = af1.sample_metadata(sample_sets=sample_sets)
- if sample_query is not None:
- df_samples = df_samples.query(sample_query)
- expected_area = np.unique(df_samples[area_by].dropna().values)
- area = ds["cohort_area"].values
- # N.B., some areas may not end up in final dataset if cohort
- # size is too small, so do a set membership test
- for a in area:
- assert a in expected_area
-
- # sanity checks for period values
- period = ds["cohort_period"].values
- if period_by == "year":
- expected_freqstr = "A-DEC"
- elif period_by == "month":
- expected_freqstr = "M"
- elif period_by == "quarter":
- expected_freqstr = "Q-DEC"
- else:
- assert False, "not implemented"
- for p in period:
- assert isinstance(p, pd.Period)
- assert p.freqstr == expected_freqstr
-
- # sanity check cohort size
- size = ds["cohort_size"].values
- for s in size:
- assert s >= min_cohort_size
-
- if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
- # Here we test the behaviour of the function when grouping by admin level
- # 1 and year. We can do some more in-depth testing in this case because
- # we can compare results directly against the simpler aa_allele_frequencies()
- # function with the admin1_year cohorts.
-
- # check consistency with the basic snp allele frequencies method
- df_af = af1.aa_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- )
- df_af = df_af.reset_index() # make sure all variables available to check
- if variant_query is not None:
- df_af = df_af.query(variant_query)
-
- # check cohorts are consistent
- expect_cohort_labels = sorted(
- [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
- )
- cohort_labels = sorted(ds["cohort_label"].values)
- assert cohort_labels == expect_cohort_labels
-
- # check variants are consistent
- assert ds.sizes["variants"] == len(df_af)
- for v in expected_variant_vars:
- c = v.split("variant_")[1]
- actual = ds[v]
- expect = df_af[c]
- _compare_series_like(actual, expect)
-
- # check frequencies are consistent
- for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
- print(cohort_label)
- actual_frq = ds["event_frequency"].values[:, cohort_index]
- expect_frq = df_af[f"frq_{cohort_label}"].values
- assert_allclose(actual_frq, expect_frq)
-
-
-# Here we don't explore the full matrix, but vary one parameter at a time, otherwise
-# the test suite would take too long to run.
-
-
-@pytest.mark.parametrize(
- "transcript", ["LOC125767311_t2", "LOC125761549_t5", "LOC125761549_t7"]
-)
-def test_allele_frequencies_advanced__transcript(transcript):
- _check_snp_allele_frequencies_advanced(
- transcript=transcript,
- )
- _check_aa_allele_frequencies_advanced(
- transcript=transcript,
- )
-
-
-@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"])
-def test_allele_frequencies_advanced__area_by(area_by):
- _check_snp_allele_frequencies_advanced(
- area_by=area_by,
- )
- _check_aa_allele_frequencies_advanced(
- area_by=area_by,
- )
-
-
-@pytest.mark.parametrize("period_by", ["year", "quarter", "month"])
-def test_allele_frequencies_advanced__period_by(period_by):
- _check_snp_allele_frequencies_advanced(
- period_by=period_by,
- )
- _check_aa_allele_frequencies_advanced(
- period_by=period_by,
- )
-
-
-@pytest.mark.parametrize(
- "sample_sets",
- [
- "1229-VO-GH-DADZIE-VMF00095",
- ["1240-VO-CD-KOEKEMOER-VMF00099", "1240-VO-MZ-KOEKEMOER-VMF00101"],
- "1.0",
- ],
-)
-def test_allele_frequencies_advanced__sample_sets(sample_sets):
- _check_snp_allele_frequencies_advanced(
- sample_sets=sample_sets,
- )
- _check_aa_allele_frequencies_advanced(
- sample_sets=sample_sets,
- )
-
-
-def test_allele_frequencies_advanced__sample_query():
- _check_snp_allele_frequencies_advanced(
- sample_query="taxon == 'funestus' and country in ['Ghana', 'Gabon']",
- )
- # noinspection PyTypeChecker
- _check_aa_allele_frequencies_advanced(
- sample_query="taxon == 'funestus' and country in ['Ghana', 'Gabon']",
- variant_query=None,
- )
-
-
-@pytest.mark.parametrize("min_cohort_size", [10, 40])
-def test_allele_frequencies_advanced__min_cohort_size(min_cohort_size):
- _check_snp_allele_frequencies_advanced(
- min_cohort_size=min_cohort_size,
- )
- _check_aa_allele_frequencies_advanced(
- min_cohort_size=min_cohort_size,
- )
-
-
-@pytest.mark.parametrize(
- "variant_query",
- [
- None,
- "effect == 'NON_SYNONYMOUS_CODING' and max_af > 0.05",
- "effect == 'foobar'", # no variants
- ],
-)
-def test_allele_frequencies_advanced__variant_query(variant_query):
- _check_snp_allele_frequencies_advanced(
- variant_query=variant_query,
- )
- _check_aa_allele_frequencies_advanced(
- variant_query=variant_query,
- )
-
-
-@pytest.mark.parametrize("nobs_mode", ["called", "fixed"])
-def test_allele_frequencies_advanced__nobs_mode(nobs_mode):
- _check_snp_allele_frequencies_advanced(
- nobs_mode=nobs_mode,
- )
- _check_aa_allele_frequencies_advanced(
- nobs_mode=nobs_mode,
- )
-
-
-# TODO: this function is a verbatim duplicate, from test_ag3.py
-def _compare_series_like(actual, expect):
- # compare pandas series-like objects for equality or floating point
- # similarity, handling missing values appropriately
-
- # handle object arrays, these don't get nans compared properly
- t = actual.dtype
- if t == object:
- expect = expect.fillna("NA")
- actual = actual.fillna("NA")
-
- if t.kind == "f":
- assert_allclose(actual.values, expect.values)
- else:
- assert_array_equal(actual.values, expect.values)
-
-
def test_h12_gwss():
af1 = setup_af1(cohorts_analysis="20230823")
sample_query = "country == 'Ghana'"
diff --git a/tests/test_ag3.py b/tests/test_ag3.py
index c20c5b617..feee66add 100644
--- a/tests/test_ag3.py
+++ b/tests/test_ag3.py
@@ -10,7 +10,7 @@
from malariagen_data import Ag3, Region
from malariagen_data.anopheles import _cn_mode
-from malariagen_data.util import locate_region, resolve_region
+from malariagen_data.util import locate_region, resolve_region, compare_series_like
contigs = "2R", "2L", "3R", "3L", "X"
@@ -54,238 +54,6 @@ def test_cross_metadata():
assert df_crosses["sex"].unique().tolist() == expected_sex_values
-def test_snp_effects():
- ag3 = setup_ag3()
- gste2 = "AGAP009194-RA"
- site_mask = "gamb_colu"
- expected_fields = [
- "contig",
- "position",
- "ref_allele",
- "alt_allele",
- "pass_gamb_colu_arab",
- "pass_gamb_colu",
- "pass_arab",
- "transcript",
- "effect",
- "impact",
- "ref_codon",
- "alt_codon",
- "aa_pos",
- "ref_aa",
- "alt_aa",
- "aa_change",
- ]
-
- df = ag3.snp_effects(transcript=gste2, site_mask=site_mask)
- assert isinstance(df, pd.DataFrame)
- assert df.columns.tolist() == expected_fields
-
- # reverse strand gene
- assert df.shape == (2838, len(expected_fields))
- # check first, second, third codon position non-syn
- assert df.iloc[1454].aa_change == "I114L"
- assert df.iloc[1446].aa_change == "I114M"
- # while we are here, check all columns for a position
- assert df.iloc[1451].position == 28598166
- assert df.iloc[1451].ref_allele == "A"
- assert df.iloc[1451].alt_allele == "G"
- assert df.iloc[1451].effect == "NON_SYNONYMOUS_CODING"
- assert df.iloc[1451].impact == "MODERATE"
- assert df.iloc[1451].ref_codon == "aTt"
- assert df.iloc[1451].alt_codon == "aCt"
- assert df.iloc[1451].aa_pos == 114
- assert df.iloc[1451].ref_aa == "I"
- assert df.iloc[1451].alt_aa == "T"
- assert df.iloc[1451].aa_change == "I114T"
- # check syn
- assert df.iloc[1447].aa_change == "I114I"
- # check intronic
- assert df.iloc[1197].effect == "INTRONIC"
- # check 5' utr
- assert df.iloc[2661].effect == "FIVE_PRIME_UTR"
- # check 3' utr
- assert df.iloc[0].effect == "THREE_PRIME_UTR"
-
- # test forward strand gene gste6
- gste6 = "AGAP009196-RA"
- df = ag3.snp_effects(transcript=gste6, site_mask=site_mask)
- assert isinstance(df, pd.DataFrame)
- assert df.columns.tolist() == expected_fields
- assert df.shape == (2829, len(expected_fields))
-
- # check first, second, third codon position non-syn
- assert df.iloc[701].aa_change == "E35*"
- assert df.iloc[703].aa_change == "E35V"
- # while we are here, check all columns for a position
- assert df.iloc[706].position == 28600605
- assert df.iloc[706].ref_allele == "G"
- assert df.iloc[706].alt_allele == "C"
- assert df.iloc[706].effect == "NON_SYNONYMOUS_CODING"
- assert df.iloc[706].impact == "MODERATE"
- assert df.iloc[706].ref_codon == "gaG"
- assert df.iloc[706].alt_codon == "gaC"
- assert df.iloc[706].aa_pos == 35
- assert df.iloc[706].ref_aa == "E"
- assert df.iloc[706].alt_aa == "D"
- assert df.iloc[706].aa_change == "E35D"
- # check syn
- assert df.iloc[705].aa_change == "E35E"
- # check intronic
- assert df.iloc[900].effect == "INTRONIC"
- # check 5' utr
- assert df.iloc[0].effect == "FIVE_PRIME_UTR"
- # check 3' utr
- assert df.iloc[2828].effect == "THREE_PRIME_UTR"
-
- # check 5' utr intron and the different intron effects
- utr_intron5 = "AGAP004679-RB"
- df = ag3.snp_effects(transcript=utr_intron5, site_mask=site_mask)
- assert isinstance(df, pd.DataFrame)
- assert df.columns.tolist() == expected_fields
- assert df.shape == (7686, len(expected_fields))
- assert df.iloc[180].effect == "SPLICE_CORE"
- assert df.iloc[198].effect == "SPLICE_REGION"
- assert df.iloc[202].effect == "INTRONIC"
-
- # check 3' utr intron
- utr_intron3 = "AGAP000689-RA"
- df = ag3.snp_effects(transcript=utr_intron3, site_mask=site_mask)
- assert isinstance(df, pd.DataFrame)
- assert df.columns.tolist() == expected_fields
- assert df.shape == (5397, len(expected_fields))
- assert df.iloc[646].effect == "SPLICE_CORE"
- assert df.iloc[652].effect == "SPLICE_REGION"
- assert df.iloc[674].effect == "INTRONIC"
-
-
-def test_snp_allele_frequencies__dict_cohorts():
- ag3 = setup_ag3(cohorts_analysis="20230516")
- cohorts = {
- "ke": "country == 'Kenya'",
- "bf_2012_col": "country == 'Burkina Faso' and year == 2012 and aim_species == 'coluzzii'",
- }
- universal_fields = [
- "pass_gamb_colu_arab",
- "pass_gamb_colu",
- "pass_arab",
- "label",
- ]
-
- # test drop invariants
- df = ag3.snp_allele_frequencies(
- transcript="AGAP009194-RA",
- cohorts=cohorts,
- site_mask="gamb_colu",
- sample_sets="3.0",
- drop_invariant=True,
- effects=False,
- )
-
- assert isinstance(df, pd.DataFrame)
- frq_columns = ["frq_" + s for s in list(cohorts.keys())]
- expected_fields = universal_fields + frq_columns + ["max_af"]
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.shape == (133, len(expected_fields))
- assert df.iloc[3].frq_ke == 0
- assert df.iloc[4].frq_bf_2012_col == pytest.approx(0.006097, abs=1e-6)
- assert df.iloc[4].max_af == pytest.approx(0.006097, abs=1e-6)
- # check invariant have been dropped
- assert df.max_af.min() > 0
-
- # test keep invariants
- df = ag3.snp_allele_frequencies(
- transcript="AGAP004707-RD",
- cohorts=cohorts,
- site_mask="gamb_colu",
- sample_sets="3.0",
- drop_invariant=False,
- effects=False,
- )
- assert isinstance(df, pd.DataFrame)
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.shape == (132306, len(expected_fields))
- # check invariant positions are still present
- assert np.any(df.max_af == 0)
-
-
-def test_snp_allele_frequencies__str_cohorts__effects():
- ag3 = setup_ag3(cohorts_analysis="20230516")
- cohorts = "admin1_month"
- min_cohort_size = 10
- universal_fields = [
- "pass_gamb_colu_arab",
- "pass_gamb_colu",
- "pass_arab",
- "label",
- ]
- effects_fields = [
- "transcript",
- "effect",
- "impact",
- "ref_codon",
- "alt_codon",
- "aa_pos",
- "ref_aa",
- "alt_aa",
- ]
- df = ag3.snp_allele_frequencies(
- transcript="AGAP004707-RD",
- cohorts=cohorts,
- min_cohort_size=min_cohort_size,
- site_mask="gamb_colu",
- sample_sets="3.0",
- drop_invariant=True,
- effects=True,
- )
- df_coh = ag3.cohorts_metadata(sample_sets="3.0")
- coh_nm = "cohort_" + cohorts
- coh_counts = df_coh[coh_nm].dropna().value_counts()
- cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list()
- frq_cohort_labels = ["frq_" + s for s in cohort_labels]
- expected_fields = universal_fields + frq_cohort_labels + ["max_af"] + effects_fields
-
- assert isinstance(df, pd.DataFrame)
- assert len(df) == 16641
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.index.names == [
- "contig",
- "position",
- "ref_allele",
- "alt_allele",
- "aa_change",
- ]
-
-
-def test_snp_allele_frequencies__query():
- ag3 = setup_ag3(cohorts_analysis="20230516")
- cohorts = "admin1_year"
- min_cohort_size = 10
- expected_columns = [
- "pass_gamb_colu_arab",
- "pass_gamb_colu",
- "pass_arab",
- "frq_AO-LUA_colu_2009",
- "max_af",
- "label",
- ]
-
- df = ag3.snp_allele_frequencies(
- transcript="AGAP004707-RD",
- cohorts=cohorts,
- sample_query="country == 'Angola'",
- min_cohort_size=min_cohort_size,
- site_mask="gamb_colu",
- sample_sets="3.0",
- drop_invariant=True,
- effects=False,
- )
-
- assert isinstance(df, pd.DataFrame)
- assert sorted(df.columns) == sorted(expected_columns)
- assert len(df) == 695
-
-
@pytest.mark.parametrize("rows", [10, 100, 1000])
@pytest.mark.parametrize("cols", [10, 100, 1000])
@pytest.mark.parametrize("vmax", [2, 12, 100])
@@ -716,450 +484,6 @@ def test_locate_region(region_raw):
assert region == Region("2R", 24630355, 24633221)
-def test_aa_allele_frequencies():
- ag3 = setup_ag3(cohorts_analysis="20230516")
-
- expected_fields = [
- "transcript",
- "aa_pos",
- "ref_allele",
- "alt_allele",
- "ref_aa",
- "alt_aa",
- "effect",
- "impact",
- "frq_BF-09_gamb_2012",
- "frq_BF-09_colu_2012",
- "frq_BF-09_colu_2014",
- "frq_BF-09_gamb_2014",
- "frq_BF-07_gamb_2004",
- "max_af",
- "label",
- ]
-
- df = ag3.aa_allele_frequencies(
- transcript="AGAP004707-RD",
- cohorts="admin1_year",
- min_cohort_size=10,
- site_mask="gamb_colu",
- sample_sets=("AG1000G-BF-A", "AG1000G-BF-B", "AG1000G-BF-C"),
- drop_invariant=True,
- )
-
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert isinstance(df, pd.DataFrame)
- assert df.index.names == ["aa_change", "contig", "position"]
- assert df.shape == (61, len(expected_fields))
- assert df.loc["V402L"].max_af[0] == pytest.approx(0.121951, abs=1e-6)
-
-
-# noinspection PyDefaultArgument
-def _check_snp_allele_frequencies_advanced(
- transcript="AGAP004707-RD",
- area_by="admin1_iso",
- period_by="year",
- sample_sets=["AG1000G-BF-A", "AG1000G-ML-A", "AG1000G-UG"],
- sample_query=None,
- min_cohort_size=10,
- nobs_mode="called",
- variant_query="max_af > 0.02",
-):
- ag3 = setup_ag3(cohorts_analysis="20230516")
-
- ds = ag3.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- nobs_mode=nobs_mode,
- variant_query=variant_query,
- )
-
- assert isinstance(ds, xr.Dataset)
-
- # noinspection PyTypeChecker
- assert sorted(ds.dims) == ["cohorts", "variants"]
-
- expected_variant_vars = (
- "variant_label",
- "variant_contig",
- "variant_position",
- "variant_ref_allele",
- "variant_alt_allele",
- "variant_max_af",
- "variant_pass_gamb_colu_arab",
- "variant_pass_gamb_colu",
- "variant_pass_arab",
- "variant_transcript",
- "variant_effect",
- "variant_impact",
- "variant_ref_codon",
- "variant_alt_codon",
- "variant_ref_aa",
- "variant_alt_aa",
- "variant_aa_pos",
- "variant_aa_change",
- )
- for v in expected_variant_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants",)
-
- expected_cohort_vars = (
- "cohort_label",
- "cohort_size",
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_period_start",
- "cohort_period_end",
- "cohort_lat_mean",
- "cohort_lat_min",
- "cohort_lat_max",
- "cohort_lon_mean",
- "cohort_lon_min",
- "cohort_lon_max",
- )
- for v in expected_cohort_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("cohorts",)
-
- expected_event_vars = (
- "event_count",
- "event_nobs",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- )
- for v in expected_event_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants", "cohorts")
-
- # sanity checks for area values
- df_samples = ag3.sample_metadata(sample_sets=sample_sets)
- if sample_query is not None:
- df_samples = df_samples.query(sample_query)
- expected_area = np.unique(df_samples[area_by].dropna().values)
- area = ds["cohort_area"].values
- # N.B., some areas may not end up in final dataset if cohort
- # size is too small, so do a set membership test
- for a in area:
- assert a in expected_area
-
- # sanity checks for period values
- period = ds["cohort_period"].values
- if period_by == "year":
- expected_freqstr = "A-DEC"
- elif period_by == "month":
- expected_freqstr = "M"
- elif period_by == "quarter":
- expected_freqstr = "Q-DEC"
- else:
- assert False, "not implemented"
- for p in period:
- assert isinstance(p, pd.Period)
- assert p.freqstr == expected_freqstr
-
- # sanity check cohort size
- size = ds["cohort_size"].values
- for s in size:
- assert s >= min_cohort_size
-
- if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
- # Here we test the behaviour of the function when grouping by admin level
- # 1 and year. We can do some more in-depth testing in this case because
- # we can compare results directly against the simpler snp_allele_frequencies()
- # function with the admin1_year cohorts.
-
- # check consistency with the basic snp allele frequencies method
- df_af = ag3.snp_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- )
- df_af = df_af.reset_index() # make sure all variables available to check
- if variant_query is not None:
- df_af = df_af.query(variant_query)
-
- # check cohorts are consistent
- expect_cohort_labels = sorted(
- [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
- )
- cohort_labels = sorted(ds["cohort_label"].values)
- assert cohort_labels == expect_cohort_labels
-
- # check variants are consistent
- assert ds.sizes["variants"] == len(df_af)
- for v in expected_variant_vars:
- c = v.split("variant_")[1]
- actual = ds[v]
- expect = df_af[c]
- _compare_series_like(actual, expect)
-
- # check frequencies are consistent
- for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
- actual_frq = ds["event_frequency"].values[:, cohort_index]
- expect_frq = df_af[f"frq_{cohort_label}"].values
- assert_allclose(actual_frq, expect_frq)
-
-
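The period sanity checks in the helper above (and repeated in the amino acid helper below) assert pandas Period frequency strings for yearly, quarterly and monthly cohorts. A standalone illustration of those strings, assuming only that pandas is installed; note that recent pandas releases report "Y-DEC" rather than "A-DEC" for yearly periods:

import pandas as pd

# Yearly, quarterly and monthly periods and the frequency strings they report.
for value, freq in [("2012", "A-DEC"), ("2012Q3", "Q-DEC"), ("2012-07", "M")]:
    p = pd.Period(value, freq=freq)
    print(p, p.freqstr)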
-# noinspection PyDefaultArgument
-def _check_aa_allele_frequencies_advanced(
- transcript="AGAP004707-RD",
- area_by="admin1_iso",
- period_by="year",
- sample_sets=["AG1000G-BF-A", "AG1000G-ML-A", "AG1000G-UG"],
- sample_query=None,
- min_cohort_size=10,
- nobs_mode="called",
- variant_query="max_af > 0.02",
-):
- ag3 = setup_ag3(cohorts_analysis="20230516")
-
- ds = ag3.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by=area_by,
- period_by=period_by,
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- nobs_mode=nobs_mode,
- variant_query=variant_query,
- )
-
- assert isinstance(ds, xr.Dataset)
-
- # noinspection PyTypeChecker
- assert sorted(ds.dims) == ["cohorts", "variants"]
-
- expected_variant_vars = (
- "variant_label",
- "variant_contig",
- "variant_position",
- "variant_max_af",
- "variant_transcript",
- "variant_effect",
- "variant_impact",
- "variant_ref_aa",
- "variant_alt_aa",
- "variant_aa_pos",
- "variant_aa_change",
- )
- for v in expected_variant_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants",)
-
- expected_cohort_vars = (
- "cohort_label",
- "cohort_size",
- "cohort_taxon",
- "cohort_area",
- "cohort_period",
- "cohort_period_start",
- "cohort_period_end",
- "cohort_lat_mean",
- "cohort_lat_min",
- "cohort_lat_max",
- "cohort_lon_mean",
- "cohort_lon_min",
- "cohort_lon_max",
- )
- for v in expected_cohort_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("cohorts",)
-
- expected_event_vars = (
- "event_count",
- "event_nobs",
- "event_frequency",
- "event_frequency_ci_low",
- "event_frequency_ci_upp",
- )
- for v in expected_event_vars:
- a = ds[v]
- assert isinstance(a, xr.DataArray)
- assert a.dims == ("variants", "cohorts")
-
- # sanity checks for area values
- df_samples = ag3.sample_metadata(sample_sets=sample_sets)
- if sample_query is not None:
- df_samples = df_samples.query(sample_query)
- expected_area = np.unique(df_samples[area_by].dropna().values)
- area = ds["cohort_area"].values
- # N.B., some areas may not end up in final dataset if cohort
- # size is too small, so do a set membership test
- for a in area:
- assert a in expected_area
-
- # sanity checks for period values
- period = ds["cohort_period"].values
- if period_by == "year":
- expected_freqstr = "A-DEC"
- elif period_by == "month":
- expected_freqstr = "M"
- elif period_by == "quarter":
- expected_freqstr = "Q-DEC"
- else:
- assert False, "not implemented"
- for p in period:
- assert isinstance(p, pd.Period)
- assert p.freqstr == expected_freqstr
-
- # sanity check cohort size
- size = ds["cohort_size"].values
- for s in size:
- assert s >= min_cohort_size
-
- if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called":
- # Here we test the behaviour of the function when grouping by admin level
- # 1 and year. We can do some more in-depth testing in this case because
- # we can compare results directly against the simpler aa_allele_frequencies()
- # function with the admin1_year cohorts.
-
- # check consistency with the basic aa allele frequencies method
- df_af = ag3.aa_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=sample_sets,
- sample_query=sample_query,
- min_cohort_size=min_cohort_size,
- )
- df_af = df_af.reset_index() # make sure all variables available to check
- if variant_query is not None:
- df_af = df_af.query(variant_query)
-
- # check cohorts are consistent
- expect_cohort_labels = sorted(
- [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")]
- )
- cohort_labels = sorted(ds["cohort_label"].values)
- assert cohort_labels == expect_cohort_labels
-
- # check variants are consistent
- assert ds.sizes["variants"] == len(df_af)
- for v in expected_variant_vars:
- c = v.split("variant_")[1]
- actual = ds[v]
- expect = df_af[c]
- _compare_series_like(actual, expect)
-
- # check frequencies are consistent
- for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
- print(cohort_label)
- actual_frq = ds["event_frequency"].values[:, cohort_index]
- expect_frq = df_af[f"frq_{cohort_label}"].values
- assert_allclose(actual_frq, expect_frq)
-
-
-# Here we don't explore the full matrix, but vary one parameter at a time, otherwise
-# the test suite would take too long to run.
-
-
-@pytest.mark.parametrize("transcript", ["AGAP004707-RD", "AGAP006028-RA"])
-def test_allele_frequencies_advanced__transcript(transcript):
- _check_snp_allele_frequencies_advanced(
- transcript=transcript,
- )
- _check_aa_allele_frequencies_advanced(
- transcript=transcript,
- )
-
-
-@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"])
-def test_allele_frequencies_advanced__area_by(area_by):
- _check_snp_allele_frequencies_advanced(
- area_by=area_by,
- )
- _check_aa_allele_frequencies_advanced(
- area_by=area_by,
- )
-
-
-@pytest.mark.parametrize("period_by", ["year", "quarter", "month"])
-def test_allele_frequencies_advanced__period_by(period_by):
- _check_snp_allele_frequencies_advanced(
- period_by=period_by,
- )
- _check_aa_allele_frequencies_advanced(
- period_by=period_by,
- )
-
-
-@pytest.mark.parametrize(
- "sample_sets", ["AG1000G-BF-A", ["AG1000G-BF-A", "AG1000G-ML-A"], "3.0"]
-)
-def test_allele_frequencies_advanced__sample_sets(sample_sets):
- _check_snp_allele_frequencies_advanced(
- sample_sets=sample_sets,
- )
- _check_aa_allele_frequencies_advanced(
- sample_sets=sample_sets,
- )
-
-
-@pytest.mark.parametrize(
- "sample_query",
- [
- "taxon in ['gambiae', 'coluzzii'] and country == 'Mali'",
- "taxon == 'arabiensis' and country in ['Uganda', 'Tanzania']",
- ],
-)
-def test_allele_frequencies_advanced__sample_query(sample_query):
- _check_snp_allele_frequencies_advanced(
- sample_query=sample_query,
- )
- # noinspection PyTypeChecker
- _check_aa_allele_frequencies_advanced(
- sample_query=sample_query,
- variant_query=None,
- )
-
-
-@pytest.mark.parametrize("min_cohort_size", [10, 100])
-def test_allele_frequencies_advanced__min_cohort_size(min_cohort_size):
- _check_snp_allele_frequencies_advanced(
- min_cohort_size=min_cohort_size,
- )
- _check_aa_allele_frequencies_advanced(
- min_cohort_size=min_cohort_size,
- )
-
-
-@pytest.mark.parametrize(
- "variant_query",
- [
- None,
- "effect == 'NON_SYNONYMOUS_CODING' and max_af > 0.05",
- "effect == 'foobar'", # no variants
- ],
-)
-def test_allele_frequencies_advanced__variant_query(variant_query):
- _check_snp_allele_frequencies_advanced(
- variant_query=variant_query,
- )
- _check_aa_allele_frequencies_advanced(
- variant_query=variant_query,
- )
-
-
-@pytest.mark.parametrize("nobs_mode", ["called", "fixed"])
-def test_allele_frequencies_advanced__nobs_mode(nobs_mode):
- _check_snp_allele_frequencies_advanced(
- nobs_mode=nobs_mode,
- )
- _check_aa_allele_frequencies_advanced(
- nobs_mode=nobs_mode,
- )
-
-
# noinspection PyDefaultArgument
def _check_gene_cnv_frequencies_advanced(
region="2L",
@@ -1303,7 +627,7 @@ def _check_gene_cnv_frequencies_advanced(
c = v.split("variant_")[1]
actual = ds[v]
expect = df_af[c]
- _compare_series_like(actual, expect)
+ compare_series_like(actual, expect)
# check frequencies are consistent
for cohort_index, cohort_label in enumerate(ds["cohort_label"].values):
@@ -1435,7 +759,7 @@ def test_gene_cnv_frequencies_advanced__multi_contig_x():
for v in ds1:
a = ds1[v]
b = ds2[v]
- _compare_series_like(a, b)
+ compare_series_like(a, b)
def test_gene_cnv_frequencies_advanced__missing_samples():
@@ -1469,22 +793,6 @@ def test_gene_cnv_frequencies_advanced__dup_samples():
assert ds.dims == ds_dup.dims
-def _compare_series_like(actual, expect):
- # compare pandas series-like objects for equality or floating point
- # similarity, handling missing values appropriately
-
- # handle object arrays, these don't get nans compared properly
- t = actual.dtype
- if t == object:
- expect = expect.fillna("NA")
- actual = actual.fillna("NA")
-
- if t.kind == "f":
- assert_allclose(actual.values, expect.values)
- else:
- assert_array_equal(actual.values, expect.values)
-
-
def test_h12_gwss():
ag3 = setup_ag3(cohorts_analysis="20230516")
sample_query = "country == 'Ghana'"
diff --git a/tests/test_anopheles.py b/tests/test_anopheles.py
index 9b2ca7bd1..1046d0682 100644
--- a/tests/test_anopheles.py
+++ b/tests/test_anopheles.py
@@ -1,8 +1,6 @@
import numpy as np
-import pandas as pd
import pytest
from numpy.testing import assert_allclose
-from pandas.testing import assert_frame_equal
from malariagen_data import Af1, Ag3
from malariagen_data.af1 import GCS_URL as AF1_GCS_URL
@@ -43,243 +41,6 @@ def setup_subclass_cached(subclass, **kwargs):
return setup_subclass(subclass, url=url, **kwargs)
-@pytest.mark.parametrize(
- "subclass, sample_sets, universal_fields, transcript, site_mask, cohorts_analysis, expected_snp_count",
- [
- (
- Ag3,
- "3.0",
- [
- "pass_gamb_colu_arab",
- "pass_gamb_colu",
- "pass_arab",
- "label",
- ],
- "AGAP004707-RD",
- "gamb_colu",
- "20211101",
- 16526,
- ),
- (
- Af1,
- "1.0",
- [
- "pass_funestus",
- "label",
- ],
- "LOC125767311_t2",
- "funestus",
- "20221129",
- 4221,
- ),
- ],
-)
-def test_snp_allele_frequencies__str_cohorts(
- subclass,
- sample_sets,
- universal_fields,
- transcript,
- site_mask,
- cohorts_analysis,
- expected_snp_count,
-):
- anoph = setup_subclass_cached(subclass, cohorts_analysis=cohorts_analysis)
-
- cohorts = "admin1_month"
- min_cohort_size = 10
- df = anoph.snp_allele_frequencies(
- transcript=transcript,
- cohorts=cohorts,
- min_cohort_size=min_cohort_size,
- site_mask=site_mask,
- sample_sets=sample_sets,
- drop_invariant=True,
- effects=False,
- )
- df_coh = anoph.cohorts_metadata(sample_sets=sample_sets)
- coh_nm = "cohort_" + cohorts
- coh_counts = df_coh[coh_nm].dropna().value_counts()
- cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list()
- frq_cohort_labels = ["frq_" + s for s in cohort_labels]
- expected_fields = universal_fields + frq_cohort_labels + ["max_af"]
-
- assert isinstance(df, pd.DataFrame)
- assert sorted(df.columns.tolist()) == sorted(expected_fields)
- assert df.index.names == ["contig", "position", "ref_allele", "alt_allele"]
- assert len(df) == expected_snp_count
-
-
-@pytest.mark.parametrize(
- "subclass, transcript, sample_set",
- [
- (
- Ag3,
- "AGAP004707-RD",
- "AG1000G-FR",
- ),
- (
- Af1,
- "LOC125767311_t2",
- "1229-VO-GH-DADZIE-VMF00095",
- ),
- ],
-)
-def test_snp_allele_frequencies__dup_samples(
- subclass,
- transcript,
- sample_set,
-):
- # Expect any duplicate sample sets to be deduplicated automatically.
- anoph = setup_subclass_cached(subclass)
- df = anoph.snp_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=[sample_set],
- )
- df_dup = anoph.snp_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=[sample_set, sample_set],
- )
- assert_frame_equal(df, df_dup)
-
-
-@pytest.mark.parametrize(
- "subclass, transcript, sample_sets",
- [
- (
- Ag3,
- "foobar",
- "3.0",
- ),
- (
- Af1,
- "foobar",
- "1.0",
- ),
- ],
-)
-def test_snp_allele_frequencies__bad_transcript(
- subclass,
- transcript,
- sample_sets,
-):
- anoph = setup_subclass_cached(subclass)
- with pytest.raises(ValueError):
- anoph.snp_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=sample_sets,
- )
-
-
-@pytest.mark.parametrize(
- "subclass, cohorts_analysis, transcript, sample_set",
- [
- (
- Ag3,
- "20211101",
- "AGAP004707-RD",
- "AG1000G-FR",
- ),
- (
- Af1,
- "20221129",
- "LOC125767311_t2",
- "1229-VO-GH-DADZIE-VMF00095",
- ),
- ],
-)
-def test_aa_allele_frequencies__dup_samples(
- subclass, cohorts_analysis, transcript, sample_set
-):
- # Expect duplicate sample sets to be deduplicated automatically.
- anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis)
- df = anoph.aa_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=[sample_set],
- )
- df_dup = anoph.aa_allele_frequencies(
- transcript=transcript,
- cohorts="admin1_year",
- sample_sets=[sample_set, sample_set],
- )
- assert_frame_equal(df, df_dup)
-
-
-@pytest.mark.parametrize(
- "subclass, cohorts_analysis, transcript, sample_set",
- [
- (
- Ag3,
- "20211101",
- "AGAP004707-RD",
- "AG1000G-FR",
- ),
- (
- Af1,
- "20221129",
- "LOC125767311_t2",
- "1229-VO-GH-DADZIE-VMF00095",
- ),
- ],
-)
-def test_snp_allele_frequencies_advanced__dup_samples(
- subclass, cohorts_analysis, transcript, sample_set
-):
- anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis)
- ds = anoph.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[sample_set],
- )
- ds_dup = anoph.snp_allele_frequencies_advanced(
- transcript=transcript,
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[sample_set, sample_set],
- )
- assert ds.dims == ds_dup.dims
-
-
-@pytest.mark.parametrize(
- "subclass, cohorts_analysis, transcript, sample_set",
- [
- (
- Ag3,
- "20211101",
- "AGAP004707-RD",
- "AG1000G-FR",
- ),
- (
- Af1,
- "20221129",
- "LOC125767311_t2",
- "1229-VO-GH-DADZIE-VMF00095",
- ),
- ],
-)
-def test_aa_allele_frequencies_advanced__dup_samples(
- subclass, cohorts_analysis, transcript, sample_set
-):
- anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis)
- ds_dup = anoph.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[sample_set, sample_set],
- )
- ds = anoph.aa_allele_frequencies_advanced(
- transcript=transcript,
- area_by="admin1_iso",
- period_by="year",
- sample_sets=[sample_set],
- )
- assert ds.dims == ds_dup.dims
-
-
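The dup-samples tests removed here only assert that passing the same sample set twice yields the same result as passing it once; how malariagen_data deduplicates sample sets internally is not shown in this diff. A minimal, order-preserving sketch of that behaviour, using a sample set name taken from the removed tests:

def dedup_sample_sets(sample_sets):
    # Drop repeated sample set identifiers while preserving first-seen order.
    seen = set()
    deduped = []
    for s in sample_sets:
        if s not in seen:
            seen.add(s)
            deduped.append(s)
    return deduped


assert dedup_sample_sets(["AG1000G-FR", "AG1000G-FR"]) == ["AG1000G-FR"]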
def test_haplotype_frequencies():
h1 = np.array(
[