diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index 66feaa829..5b4a4fd08 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -28,7 +28,7 @@ jobs: run: poetry install - name: Run tests with coverage - run: poetry run pytest -v --cov malariagen_data/anoph --cov-report=xml tests/anoph + run: poetry run pytest --durations=20 --durations-min=1.0 -v --cov malariagen_data/anoph --cov-report=xml tests/anoph - name: Upload coverage report uses: codecov/codecov-action@v3 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c5e31ee1..28d8f13a9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,7 +41,7 @@ jobs: key: gcs_cache_tests_20231119 - name: Run full test suite - run: poetry run pytest -v tests + run: poetry run pytest --durations=20 --durations-min=10.0 -v tests - name: Save GCS cache uses: actions/cache/save@v3 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9bdb021ea..34d8c53c1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,5 +12,7 @@ repos: hooks: # Run the linter. - id: ruff + args: + - "--fix" # Run the formatter. - id: ruff-format diff --git a/malariagen_data/af1.py b/malariagen_data/af1.py index eb1f3e342..fda0d39f2 100644 --- a/malariagen_data/af1.py +++ b/malariagen_data/af1.py @@ -128,6 +128,7 @@ def __init__( tqdm_class=tqdm_class, taxon_colors=TAXON_COLORS, virtual_contigs=None, + gene_names=None, ) def __repr__(self): @@ -209,21 +210,3 @@ def _repr_html_(self): """ return html - - def _transcript_to_gene_name(self, transcript): - df_genome_features = self.genome_features().set_index("ID") - rec_transcript = df_genome_features.loc[transcript] - parent = rec_transcript["Parent"] - - # E.g. manual overrides (used in Ag3) - # if parent == "AGAP004707": - # parent_name = "Vgsc/para" - # else: - # parent_name = rec_parent["Name"] - - # Note: Af1 doesn't have the "Name" attribute - # rec_parent = df_genome_features.loc[parent] - # parent_name = rec_parent["Name"] - parent_name = parent - - return parent_name diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py index 8431b0d0d..512431077 100644 --- a/malariagen_data/ag3.py +++ b/malariagen_data/ag3.py @@ -27,6 +27,9 @@ "3RL": ("3R", "3L"), "23X": ("2R", "2L", "3R", "3L", "X"), } +GENE_NAMES = { + "AGAP004707": "Vgsc/para", +} def _setup_aim_palettes(): @@ -191,6 +194,7 @@ def __init__( tqdm_class=tqdm_class, taxon_colors=TAXON_COLORS, virtual_contigs=VIRTUAL_CONTIGS, + gene_names=GENE_NAMES, ) # set up caches @@ -293,20 +297,6 @@ def _repr_html_(self): """ return html - def _transcript_to_gene_name(self, transcript): - df_genome_features = self.genome_features().set_index("ID") - rec_transcript = df_genome_features.loc[transcript] - parent = rec_transcript["Parent"] - rec_parent = df_genome_features.loc[parent] - - # manual overrides - if parent == "AGAP004707": - parent_name = "Vgsc/para" - else: - parent_name = rec_parent["Name"] - - return parent_name - def cross_metadata(self): """Load a dataframe containing metadata about samples in colony crosses, including which samples are parents or progeny in which crosses. 
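These hunks drop the duplicated `_transcript_to_gene_name` methods from `af1.py` and `ag3.py` in favour of a `gene_names` mapping passed to the shared constructor (for Ag3, `GENE_NAMES = {"AGAP004707": "Vgsc/para"}`); the shared lookup is added to `genome_features.py` later in this diff as `_transcript_to_parent_name`. Below is a minimal standalone sketch of the override-then-fallback lookup, using a hypothetical features table rather than the real class.

```python
import pandas as pd

# Hypothetical genome-features table indexed by feature ID; the real code
# builds this from the GFF3 via self.genome_features().set_index("ID").
features = pd.DataFrame(
    {
        "ID": ["AGAP004707-RA", "AGAP004707"],
        "Parent": ["AGAP004707", None],
        "Name": [None, "para"],
    }
).set_index("ID")

# Manual overrides, as supplied via the new gene_names constructor parameter.
gene_name_overrides = {"AGAP004707": "Vgsc/para"}


def transcript_to_parent_name(transcript):
    # Resolution order: manual override, then the GFF "Name" attribute,
    # then the bare parent ID.
    try:
        rec = features.loc[transcript]
    except KeyError:
        return None
    parent_id = rec["Parent"]
    if parent_id in gene_name_overrides:
        return gene_name_overrides[parent_id]
    parent = features.loc[parent_id]
    name = parent.get("Name")
    return name if name else parent_id


print(transcript_to_parent_name("AGAP004707-RA"))  # -> "Vgsc/para"
```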
diff --git a/malariagen_data/anoph/cnv_data.py b/malariagen_data/anoph/cnv_data.py index fe50a639a..3a9cebf3c 100644 --- a/malariagen_data/anoph/cnv_data.py +++ b/malariagen_data/anoph/cnv_data.py @@ -163,60 +163,61 @@ def cnv_hmm( regions: List[Region] = parse_multi_region(self, region) del region - debug("access CNV HMM data and concatenate as needed") - lx = [] - for r in regions: - ly = [] - for s in sample_sets: - y = self._cnv_hmm_dataset( - contig=r.contig, - sample_set=s, - inline_array=inline_array, - chunks=chunks, + with self._spinner("Access CNV HMM data"): + debug("access CNV HMM data and concatenate as needed") + lx = [] + for r in regions: + ly = [] + for s in sample_sets: + y = self._cnv_hmm_dataset( + contig=r.contig, + sample_set=s, + inline_array=inline_array, + chunks=chunks, + ) + ly.append(y) + + debug("concatenate data from multiple sample sets") + x = simple_xarray_concat(ly, dim=DIM_SAMPLE) + + debug("handle region, do this only once - optimisation") + if r.start is not None or r.end is not None: + start = x["variant_position"].values + end = x["variant_end"].values + index = pd.IntervalIndex.from_arrays(start, end, closed="both") + # noinspection PyArgumentList + other = pd.Interval(r.start, r.end, closed="both") + loc_region = index.overlaps(other) # type: ignore + x = x.isel(variants=loc_region) + + lx.append(x) + + debug("concatenate data from multiple regions") + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) + + debug("handle sample query") + if sample_query is not None: + debug("load sample metadata") + df_samples = self.sample_metadata(sample_sets=sample_sets) + + debug("align sample metadata with CNV data") + cnv_samples = ds["sample_id"].values.tolist() + df_samples_cnv = ( + df_samples.set_index("sample_id").loc[cnv_samples].reset_index() ) - ly.append(y) - - debug("concatenate data from multiple sample sets") - x = simple_xarray_concat(ly, dim=DIM_SAMPLE) - - debug("handle region, do this only once - optimisation") - if r.start is not None or r.end is not None: - start = x["variant_position"].values - end = x["variant_end"].values - index = pd.IntervalIndex.from_arrays(start, end, closed="both") - # noinspection PyArgumentList - other = pd.Interval(r.start, r.end, closed="both") - loc_region = index.overlaps(other) # type: ignore - x = x.isel(variants=loc_region) - - lx.append(x) - - debug("concatenate data from multiple regions") - ds = simple_xarray_concat(lx, dim=DIM_VARIANT) - - debug("handle sample query") - if sample_query is not None: - debug("load sample metadata") - df_samples = self.sample_metadata(sample_sets=sample_sets) - - debug("align sample metadata with CNV data") - cnv_samples = ds["sample_id"].values.tolist() - df_samples_cnv = ( - df_samples.set_index("sample_id").loc[cnv_samples].reset_index() - ) - debug("apply the query") - loc_query_samples = df_samples_cnv.eval(sample_query).values - if np.count_nonzero(loc_query_samples) == 0: - raise ValueError(f"No samples found for query {sample_query!r}") + debug("apply the query") + loc_query_samples = df_samples_cnv.eval(sample_query).values + if np.count_nonzero(loc_query_samples) == 0: + raise ValueError(f"No samples found for query {sample_query!r}") - ds = ds.isel(samples=loc_query_samples) + ds = ds.isel(samples=loc_query_samples) - debug("handle coverage variance filter") - if max_coverage_variance is not None: - cov_var = ds["sample_coverage_variance"].values - loc_pass_samples = cov_var <= max_coverage_variance - ds = ds.isel(samples=loc_pass_samples) + debug("handle coverage 
variance filter") + if max_coverage_variance is not None: + cov_var = ds["sample_coverage_variance"].values + loc_pass_samples = cov_var <= max_coverage_variance + ds = ds.isel(samples=loc_pass_samples) return ds diff --git a/malariagen_data/anoph/frq_params.py b/malariagen_data/anoph/frq_params.py index 417e507b0..579125b43 100644 --- a/malariagen_data/anoph/frq_params.py +++ b/malariagen_data/anoph/frq_params.py @@ -65,3 +65,8 @@ `gene_cnv_frequencies_advanced()`. """, ] + +include_counts: TypeAlias = Annotated[ + bool, + "Include columns with allele counts and number of non-missing allele calls (nobs).", +] diff --git a/malariagen_data/anoph/genome_features.py b/malariagen_data/anoph/genome_features.py index 3f2bfaaaf..d6c6aa5bb 100644 --- a/malariagen_data/anoph/genome_features.py +++ b/malariagen_data/anoph/genome_features.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, Mapping import bokeh.models import bokeh.plotting @@ -26,6 +26,7 @@ def __init__( *, gff_gene_type: str, gff_default_attributes: Tuple[str, ...], + gene_names: Optional[Mapping[str, str]] = None, **kwargs, ): # N.B., this class is designed to work cooperatively, and @@ -38,6 +39,11 @@ def __init__( self._gff_gene_type = gff_gene_type self._gff_default_attributes = gff_default_attributes + # Allow manual override of gene names. + if gene_names is None: + gene_names = dict() + self._gene_name_overrides = gene_names + # Setup caches. self._cache_genome_features: Dict[Tuple[str, ...], pd.DataFrame] = dict() @@ -45,7 +51,7 @@ def __init__( def _geneset_gff3_path(self): return self.config["GENESET_GFF3_PATH"] - def geneset(self, *args, **kwargs): + def geneset(self, *args, **kwargs): # pragma: no cover """Deprecated, this method has been renamed to genome_features().""" return self.genome_features(*args, **kwargs) @@ -429,3 +435,21 @@ def _bokeh_style_genome_xaxis(fig, contig): fig.xaxis.ticker = bokeh.models.AdaptiveTicker(min_interval=1) fig.xaxis.minor_tick_line_color = None fig.xaxis[0].formatter = bokeh.models.NumeralTickFormatter(format="0,0") + + def _transcript_to_parent_name(self, transcript): + df_genome_features = self.genome_features().set_index("ID") + + try: + rec_transcript = df_genome_features.loc[transcript] + except KeyError: + return None + + parent_id = rec_transcript["Parent"] + + try: + # Manual override. + return self._gene_name_overrides[parent_id] + except KeyError: + rec_parent = df_genome_features.loc[parent_id] + # Try to access "Name" attribute, fall back to "ID" if not present. + return rec_parent.get("Name", parent_id) diff --git a/malariagen_data/anoph/sample_metadata.py b/malariagen_data/anoph/sample_metadata.py index fa158bd9c..eee4b1b95 100644 --- a/malariagen_data/anoph/sample_metadata.py +++ b/malariagen_data/anoph/sample_metadata.py @@ -969,3 +969,37 @@ def _setup_sample_hover_data_plotly( if symbol and symbol not in hover_data: hover_data.append(symbol) return hover_data + + +def locate_cohorts(*, cohorts, data): + # Build cohort dictionary where key=cohort_id, value=loc_coh. + coh_dict = {} + + if isinstance(cohorts, Mapping): + # User has supplied a custom dictionary mapping cohort identifiers + # to pandas queries. + + for coh, query in cohorts.items(): + loc_coh = data.eval(query).values + coh_dict[coh] = loc_coh + + else: + assert isinstance(cohorts, str) + # User has supplied the name of a sample metadata column. + + # Convenience to allow things like "admin1_year" instead of "cohort_admin1_year". 
+ if "cohort_" + cohorts in data.columns: + cohorts = "cohort_" + cohorts + + # Check the given cohort set exists. + if cohorts not in data.columns: + raise ValueError(f"{cohorts!r} is not a known column in the data.") + cohort_labels = data[cohorts].unique() + + # Remove the nans and sort. + cohort_labels = sorted([c for c in cohort_labels if isinstance(c, str)]) + for coh in cohort_labels: + loc_coh = data[cohorts] == coh + coh_dict[coh] = loc_coh.values + + return coh_dict diff --git a/malariagen_data/anoph/snp_frq.py b/malariagen_data/anoph/snp_frq.py new file mode 100644 index 000000000..07466ddb4 --- /dev/null +++ b/malariagen_data/anoph/snp_frq.py @@ -0,0 +1,1352 @@ +from typing import Optional, Dict, Union, Callable, List +import warnings +from textwrap import dedent + +import allel # type: ignore +import numpy as np +import pandas as pd +from numpydoc_decorator import doc # type: ignore +import xarray as xr +import numba # type: ignore +import plotly.express as px # type: ignore + +from .. import veff +from ..util import check_types, pandas_apply +from .snp_data import AnophelesSnpData +from .sample_metadata import locate_cohorts +from . import base_params, frq_params, map_params, plotly_params + + +AA_CHANGE_QUERY = ( + "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']" +) + + +class AnophelesSnpFrequencyAnalysis( + AnophelesSnpData, +): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + # Set up cache variables. + self._cache_annotator = None + + def _snp_df_melt(self, *, ds_snp: xr.Dataset) -> pd.DataFrame: + """Set up a dataframe with SNP site and filter data, + melting each alternate allele into a separate row.""" + + with self._spinner(desc="Prepare SNP dataframe"): + # Grab contig, pos, ref and alt. + contig_index = ds_snp["variant_contig"].values[0] + contig = ds_snp.attrs["contigs"][contig_index] + pos = ds_snp["variant_position"].values + alleles = ds_snp["variant_allele"].values + ref = alleles[:, 0] + alt = alleles[:, 1:] + + # Access site filters. + filter_pass = dict() + for m in self.site_mask_ids: + x = ds_snp[f"variant_filter_pass_{m}"].values + filter_pass[m] = x + + # Set up columns with contig, pos, ref, alt columns, melting + # the data out to one row per alternate allele. + cols = { + "contig": contig, + "position": np.repeat(pos, 3), + "ref_allele": np.repeat(ref.astype("U1"), 3), + "alt_allele": alt.astype("U1").flatten(), + } + + # Add mask columns. + for m in self.site_mask_ids: + x = filter_pass[m] + cols[f"pass_{m}"] = np.repeat(x, 3) + + # Construct dataframe. + df_snps = pd.DataFrame(cols) + + return df_snps + + def _snp_effect_annotator(self): + """Set up variant effect annotator.""" + if self._cache_annotator is None: + self._cache_annotator = veff.Annotator( + genome=self.open_genome(), genome_features=self.genome_features() + ) + return self._cache_annotator + + @check_types + @doc( + summary="Compute variant effects for a gene transcript.", + returns=""" + A dataframe of all possible SNP variants and their effects, one row + per variant. + """, + ) + def snp_effects( + self, + transcript: base_params.transcript, + site_mask: Optional[base_params.site_mask] = None, + ) -> pd.DataFrame: + # Access SNP data. + ds_snp = self.snp_variants( + region=transcript, + site_mask=site_mask, + ) + + # Setup initial dataframe of SNPs. 
+ df_snps = self._snp_df_melt(ds_snp=ds_snp) + + # Setup variant effect annotator. + ann = self._snp_effect_annotator() + + # Add effects to the dataframe. + ann.get_effects(transcript=transcript, variants=df_snps) + + return df_snps + + @check_types + @doc( + summary=""" + Compute SNP allele frequencies for a gene transcript. + """, + returns=""" + A dataframe of SNP allele frequencies, one row per variant allele. + """, + notes=""" + Cohorts with fewer samples than `min_cohort_size` will be excluded from + output data frame. + """, + ) + def snp_allele_frequencies( + self, + transcript: base_params.transcript, + cohorts: base_params.cohorts, + sample_query: Optional[base_params.sample_query] = None, + min_cohort_size: base_params.min_cohort_size = 10, + site_mask: Optional[base_params.site_mask] = None, + sample_sets: Optional[base_params.sample_sets] = None, + drop_invariant: frq_params.drop_invariant = True, + effects: frq_params.effects = True, + include_counts: frq_params.include_counts = False, + ) -> pd.DataFrame: + # Access sample metadata. + df_samples = self.sample_metadata( + sample_sets=sample_sets, sample_query=sample_query + ) + + # Build cohort dictionary, maps cohort labels to boolean indexers. + coh_dict = locate_cohorts(cohorts=cohorts, data=df_samples) + + # Remove cohorts below minimum cohort size. + coh_dict = { + coh: loc_coh + for coh, loc_coh in coh_dict.items() + if np.count_nonzero(loc_coh) >= min_cohort_size + } + + # Early check for no cohorts. + if len(coh_dict) == 0: + raise ValueError( + "No cohorts available for the given sample selection parameters and minimum cohort size." + ) + + # Access SNP data. + ds_snp = self.snp_calls( + region=transcript, + site_mask=site_mask, + sample_sets=sample_sets, + sample_query=sample_query, + ) + + # Early check for no SNPs. + if ds_snp.sizes["variants"] == 0: # pragma: no cover + raise ValueError("No SNPs available for the given region and site mask.") + + # Access genotypes. + gt = ds_snp["call_genotype"].data + with self._dask_progress(desc="Load SNP genotypes"): + gt = gt.compute() + + # Set up initial dataframe of SNPs. + df_snps = self._snp_df_melt(ds_snp=ds_snp) + + # Count alleles. + count_cols = dict() + nobs_cols = dict() + freq_cols = dict() + cohorts_iterator = self._progress( + coh_dict.items(), desc="Compute allele frequencies" + ) + for coh, loc_coh in cohorts_iterator: + n_samples = np.count_nonzero(loc_coh) + assert n_samples >= min_cohort_size + gt_coh = np.compress(loc_coh, gt, axis=1) + ac_coh = np.asarray(allel.GenotypeArray(gt_coh).count_alleles(max_allele=3)) + an_coh = np.sum(ac_coh, axis=1)[:, None] + with np.errstate(divide="ignore", invalid="ignore"): + af_coh = np.where(an_coh > 0, ac_coh / an_coh, np.nan) + # Melt the frequencies so we get one row per alternate allele. + frq = af_coh[:, 1:].flatten() + freq_cols["frq_" + coh] = frq + count = ac_coh[:, 1:].flatten() + count_cols["count_" + coh] = count + nobs = np.repeat(an_coh[:, 0], 3) + nobs_cols["nobs_" + coh] = nobs + + # Build a dataframe with the frequency columns. + df_freqs = pd.DataFrame(freq_cols) + df_counts = pd.DataFrame(count_cols) + df_nobs = pd.DataFrame(nobs_cols) + + # Compute max_af. + df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)}) + + # Build the final dataframe. + df_snps.reset_index(drop=True, inplace=True) + if include_counts: + df_snps = pd.concat( + [df_snps, df_freqs, df_max_af, df_counts, df_nobs], axis=1 + ) + else: + df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1) + + # Drop invariants. 
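The loop above computes per-cohort allele counts and melts the alternate-allele frequencies into one value per row; the invariant-dropping step continues below. A standalone sketch of the same computation on toy genotypes, using scikit-allel as the loop does (sample data invented for illustration):

```python
import numpy as np
import allel

# Toy genotypes: 2 sites x 3 samples x ploidy 2; alleles coded 0..3, missing as -1.
gt = np.array(
    [
        [[0, 1], [0, 0], [1, 1]],
        [[0, 2], [-1, -1], [0, 0]],
    ]
)

# Count alleles per site (max_allele=3 keeps a fixed-width array even if
# higher alleles are absent), then convert counts to frequencies.
ac = np.asarray(allel.GenotypeArray(gt).count_alleles(max_allele=3))
an = ac.sum(axis=1, keepdims=True)
with np.errstate(divide="ignore", invalid="ignore"):
    af = np.where(an > 0, ac / an, np.nan)

# Melt alt-allele frequencies to one value per (site, alt allele), matching
# the melted SNP dataframe layout.
frq = af[:, 1:].flatten()
print(frq)  # [0.5, 0.0, 0.0, 0.0, 0.25, 0.0]
```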
+ if drop_invariant: + loc_variant = df_snps["max_af"] > 0 + + # Check for no SNPs remaining after dropping invariants. + if np.count_nonzero(loc_variant) == 0: # pragma: no cover + raise ValueError("No SNPs remaining after dropping invariant SNPs.") + + df_snps = df_snps.loc[loc_variant] + + # Reset index after filtering. + df_snps.reset_index(inplace=True, drop=True) + + if effects: + # Add effect annotations. + ann = self._snp_effect_annotator() + ann.get_effects( + transcript=transcript, variants=df_snps, progress=self._progress + ) + + # Add label. + df_snps["label"] = pandas_apply( + _make_snp_label_effect, + df_snps, + columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], + ) + + # Set index. + df_snps.set_index( + ["contig", "position", "ref_allele", "alt_allele", "aa_change"], + inplace=True, + ) + + else: + # Add label. + df_snps["label"] = pandas_apply( + _make_snp_label, + df_snps, + columns=["contig", "position", "ref_allele", "alt_allele"], + ) + + # Set index. + df_snps.set_index( + ["contig", "position", "ref_allele", "alt_allele"], + inplace=True, + ) + + # Add dataframe metadata. + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + title += " SNP frequencies" + df_snps.attrs["title"] = title + + return df_snps + + @check_types + @doc( + summary=""" + Compute amino acid substitution frequencies for a gene transcript. + """, + returns=""" + A dataframe of amino acid allele frequencies, one row per + substitution. + """, + notes=""" + Cohorts with fewer samples than `min_cohort_size` will be excluded from + output data frame. + """, + ) + def aa_allele_frequencies( + self, + transcript: base_params.transcript, + cohorts: base_params.cohorts, + sample_query: Optional[base_params.sample_query] = None, + min_cohort_size: Optional[base_params.min_cohort_size] = 10, + site_mask: Optional[base_params.site_mask] = None, + sample_sets: Optional[base_params.sample_sets] = None, + drop_invariant: frq_params.drop_invariant = True, + include_counts: frq_params.include_counts = False, + ) -> pd.DataFrame: + df_snps = self.snp_allele_frequencies( + transcript=transcript, + cohorts=cohorts, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=drop_invariant, + effects=True, + include_counts=include_counts, + ) + df_snps.reset_index(inplace=True) + + # We just want aa change. + df_ns_snps = df_snps.query(AA_CHANGE_QUERY).copy() + + # Early check for no matching SNPs. + if len(df_ns_snps) == 0: # pragma: no cover + raise ValueError( + "No amino acid change SNPs found for the given transcript and site mask." + ) + + # N.B., we need to worry about the possibility of the + # same aa change due to SNPs at different positions. We cannot + # sum frequencies of SNPs at different genomic positions. This + # is why we group by position and aa_change, not just aa_change. + + # Group and sum to collapse multi variant allele changes. + freq_cols = [col for col in df_ns_snps if col.startswith("frq_")] + + # Special handling here to ensure nans don't get summed to zero. + # See also https://github.com/pandas-dev/pandas/issues/20824#issuecomment-705376621 + def np_sum(g): + return np.sum(g.values) + + agg: Dict[str, Union[Callable, str]] = {c: np_sum for c in freq_cols} + + # Add in counts and observations data if requested. 
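The custom `np_sum` aggregator above exists because pandas' built-in `"sum"` skips NaNs, which would silently turn an all-NaN cohort (no allele calls at these sites) into a frequency of 0; the counts/nobs aggregation continues below. A minimal illustration with toy data:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "aa_change": ["L995F", "L995F"],
        "frq_coh1": [0.1, 0.2],
        "frq_coh2": [np.nan, np.nan],  # cohort with no calls at these sites
    }
)


def np_sum(g):
    # Sum over raw values: any NaN propagates, so "no data" stays NaN.
    return np.sum(g.values)


print(df.groupby("aa_change").agg({"frq_coh1": "sum", "frq_coh2": "sum"}))
# frq_coh2 -> 0.0 (pandas silently skips the NaNs)
print(df.groupby("aa_change").agg({"frq_coh1": np_sum, "frq_coh2": np_sum}))
# frq_coh2 -> NaN (missing data preserved)
```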
+ if include_counts: + count_cols = [col for col in df_ns_snps if col.startswith("count_")] + for c in count_cols: + agg[c] = "sum" + nobs_cols = [col for col in df_ns_snps if col.startswith("nobs_")] + for c in nobs_cols: + agg[c] = "first" + + keep_cols = ( + "contig", + "transcript", + "aa_pos", + "ref_allele", + "ref_aa", + "alt_aa", + "effect", + "impact", + ) + for c in keep_cols: + agg[c] = "first" + agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v + df_aaf = df_ns_snps.groupby(["position", "aa_change"]).agg(agg).reset_index() + + # Compute new max_af. + df_aaf["max_af"] = df_aaf[freq_cols].max(axis=1) + + # Add label. + df_aaf["label"] = pandas_apply( + _make_snp_label_aa, + df_aaf, + columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"], + ) + + # Sort by genomic position. + df_aaf = df_aaf.sort_values(["position", "aa_change"]) + + # Set index. + df_aaf.set_index(["aa_change", "contig", "position"], inplace=True) + + # Add metadata. + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + title += " SNP frequencies" + df_aaf.attrs["title"] = title + + return df_aaf + + @check_types + @doc( + summary=""" + Group samples by taxon, area (space) and period (time), then compute + SNP allele frequencies. + """, + returns=""" + The resulting dataset contains data has dimensions "cohorts" and + "variants". Variables prefixed with "cohort" are 1-dimensional + arrays with data about the cohorts, such as the area, period, taxon + and cohort size. Variables prefixed with "variant" are + 1-dimensional arrays with data about the variants, such as the + contig, position, reference and alternate alleles. Variables + prefixed with "event" are 2-dimensional arrays with the allele + counts and frequency calculations. + """, + ) + def snp_allele_frequencies_advanced( + self, + transcript: base_params.transcript, + area_by: frq_params.area_by, + period_by: frq_params.period_by, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + min_cohort_size: base_params.min_cohort_size = 10, + drop_invariant: frq_params.drop_invariant = True, + variant_query: Optional[frq_params.variant_query] = None, + site_mask: Optional[base_params.site_mask] = None, + nobs_mode: frq_params.nobs_mode = frq_params.nobs_mode_default, + ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default, + ) -> xr.Dataset: + # Load sample metadata. + df_samples = self.sample_metadata( + sample_sets=sample_sets, sample_query=sample_query + ) + + # Prepare sample metadata for cohort grouping. + df_samples = _prep_samples_for_cohort_grouping( + df_samples=df_samples, + area_by=area_by, + period_by=period_by, + ) + + # Group samples to make cohorts. + group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"]) + + # Build cohorts dataframe. + df_cohorts = _build_cohorts_from_sample_grouping( + group_samples_by_cohort=group_samples_by_cohort, + min_cohort_size=min_cohort_size, + ) + + # Early check for no cohorts. + if len(df_cohorts) == 0: + raise ValueError( + "No cohorts available for the given sample selection parameters and minimum cohort size." + ) + + # Access SNP calls. + ds_snps = self.snp_calls( + region=transcript, + sample_sets=sample_sets, + sample_query=sample_query, + site_mask=site_mask, + ) + + # Early check for no SNPs. 
+ if ds_snps.sizes["variants"] == 0: # pragma: no cover + raise ValueError("No SNPs available for the given region and site mask.") + + # Access genotypes. + gt = ds_snps["call_genotype"].data + with self._dask_progress(desc="Load SNP genotypes"): + gt = gt.compute() + + # Set up variant variables. + contigs = ds_snps.attrs["contigs"] + variant_contig = np.repeat( + [contigs[i] for i in ds_snps["variant_contig"].values], 3 + ) + variant_position = np.repeat(ds_snps["variant_position"].values, 3) + alleles = ds_snps["variant_allele"].values + variant_ref_allele = np.repeat(alleles[:, 0], 3) + variant_alt_allele = alleles[:, 1:].flatten() + variant_pass = dict() + for site_mask in self.site_mask_ids: + variant_pass[site_mask] = np.repeat( + ds_snps[f"variant_filter_pass_{site_mask}"].values, 3 + ) + + # Set up main event variables. + n_variants, n_cohorts = len(variant_position), len(df_cohorts) + count = np.zeros((n_variants, n_cohorts), dtype=int) + nobs = np.zeros((n_variants, n_cohorts), dtype=int) + + # Build event count and nobs for each cohort. + cohorts_iterator = self._progress( + enumerate(df_cohorts.itertuples()), + total=len(df_cohorts), + desc="Compute SNP allele frequencies", + ) + for cohort_index, cohort in cohorts_iterator: + cohort_key = cohort.taxon, cohort.area, cohort.period + sample_indices = group_samples_by_cohort.indices[cohort_key] + + cohort_ac, cohort_an = _cohort_alt_allele_counts_melt( + gt=gt, + indices=sample_indices, + max_allele=3, + ) + count[:, cohort_index] = cohort_ac + + if nobs_mode == "called": + nobs[:, cohort_index] = cohort_an + else: + assert nobs_mode == "fixed" + nobs[:, cohort_index] = cohort.size * 2 + + # Compute frequency. + with np.errstate(divide="ignore", invalid="ignore"): + # Ignore division warnings. + frequency = count / nobs + + # Compute maximum frequency over cohorts. + with warnings.catch_warnings(): + # Ignore "All-NaN slice encountered" warnings. + warnings.simplefilter("ignore", category=RuntimeWarning) + max_af = np.nanmax(frequency, axis=1) + + # Make dataframe of SNPs. + df_variants_cols = { + "contig": variant_contig, + "position": variant_position, + "ref_allele": variant_ref_allele.astype("U1"), + "alt_allele": variant_alt_allele.astype("U1"), + "max_af": max_af, + } + for site_mask in self.site_mask_ids: + df_variants_cols[f"pass_{site_mask}"] = variant_pass[site_mask] + df_variants = pd.DataFrame(df_variants_cols) + + # Deal with SNP alleles not observed. + if drop_invariant: + loc_variant = max_af > 0 + + # Check for no SNPs remaining after dropping invariants. + if np.count_nonzero(loc_variant) == 0: # pragma: no cover + raise ValueError("No SNPs remaining after dropping invariant SNPs.") + + df_variants = df_variants.loc[loc_variant].reset_index(drop=True) + count = np.compress(loc_variant, count, axis=0) + nobs = np.compress(loc_variant, nobs, axis=0) + frequency = np.compress(loc_variant, frequency, axis=0) + + # Set up variant effect annotator. + ann = self._snp_effect_annotator() + + # Add effects to the dataframe. + ann.get_effects( + transcript=transcript, variants=df_variants, progress=self._progress + ) + + # Add variant labels. + df_variants["label"] = pandas_apply( + _make_snp_label_effect, + df_variants, + columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], + ) + + # Build the output dataset. + ds_out = xr.Dataset() + + # Cohort variables. + for coh_col in df_cohorts.columns: + ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col] + + # Variant variables. 
+ for snp_col in df_variants.columns: + ds_out[f"variant_{snp_col}"] = "variants", df_variants[snp_col] + + # Event variables. + ds_out["event_count"] = ("variants", "cohorts"), count + ds_out["event_nobs"] = ("variants", "cohorts"), nobs + ds_out["event_frequency"] = ("variants", "cohorts"), frequency + + # Apply variant query. + if variant_query is not None: + loc_variants = df_variants.eval(variant_query).values + + # Check for no SNPs remaining after applying variant query. + if np.count_nonzero(loc_variants) == 0: + raise ValueError( + f"No SNPs remaining after applying variant query {variant_query!r}." + ) + + ds_out = ds_out.isel(variants=loc_variants) + + # Add confidence intervals. + _add_frequency_ci(ds=ds_out, ci_method=ci_method) + + # Tidy up display by sorting variables. + sorted_vars: List[str] = sorted([str(k) for k in ds_out.keys()]) + ds_out = ds_out[sorted_vars] + + # Add metadata. + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + title += " SNP frequencies" + ds_out.attrs["title"] = title + + return ds_out + + @check_types + @doc( + summary=""" + Group samples by taxon, area (space) and period (time), then compute + amino acid change allele frequencies. + """, + returns=""" + The resulting dataset contains data has dimensions "cohorts" and + "variants". Variables prefixed with "cohort" are 1-dimensional + arrays with data about the cohorts, such as the area, period, taxon + and cohort size. Variables prefixed with "variant" are + 1-dimensional arrays with data about the variants, such as the + contig, position, reference and alternate alleles. Variables + prefixed with "event" are 2-dimensional arrays with the allele + counts and frequency calculations. + """, + ) + def aa_allele_frequencies_advanced( + self, + transcript: base_params.transcript, + area_by: frq_params.area_by, + period_by: frq_params.period_by, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + min_cohort_size: base_params.min_cohort_size = 10, + variant_query: Optional[frq_params.variant_query] = None, + site_mask: Optional[base_params.site_mask] = None, + nobs_mode: frq_params.nobs_mode = "called", + ci_method: Optional[frq_params.ci_method] = "wilson", + ) -> xr.Dataset: + # Begin by computing SNP allele frequencies. + ds_snp_frq = self.snp_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + drop_invariant=True, # always drop invariant for aa frequencies + variant_query=AA_CHANGE_QUERY, # we'll also apply a variant query later + site_mask=site_mask, + nobs_mode=nobs_mode, + ci_method=None, # we will recompute confidence intervals later + ) + + # N.B., we need to worry about the possibility of the + # same aa change due to SNPs at different positions. We cannot + # sum frequencies of SNPs at different genomic positions. This + # is why we group by position and aa_change, not just aa_change. + + # Add in a special grouping column to work around the fact that xarray currently + # doesn't support grouping by multiple variables in the same dimension. 
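The comment above describes the workaround that the following hunk implements: xarray's `groupby` only accepts a single variable per dimension, so genomic position and amino-acid change are concatenated into one string-valued key before grouping. A self-contained toy illustration (all values invented):

```python
import numpy as np
import xarray as xr

# Toy dataset with two variant-level variables on the same dimension.
ds = xr.Dataset(
    {
        "variant_position": ("variants", [995, 995, 1570]),
        "variant_aa_change": ("variants", ["L995F", "L995F", "A1570V"]),
        "event_count": (("variants", "cohorts"), np.array([[1], [2], [3]])),
    }
)

# Combine the two keys into one string-valued variable and group on that.
key = [
    f"{p}_{a}"
    for p, a in zip(ds["variant_position"].values, ds["variant_aa_change"].values)
]
ds["variant_position_aa_change"] = ("variants", np.array(key, dtype=object))

# Sum event counts within each (position, aa_change) group.
summed = (
    ds["event_count"].groupby(ds["variant_position_aa_change"]).sum(dim="variants")
)
print(summed)  # counts 3 for "995_L995F", 3 for "1570_A1570V"
```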
+ df_grouper = ds_snp_frq[ + ["variant_position", "variant_aa_change"] + ].to_dataframe() + grouper_var = df_grouper.apply( + lambda row: "_".join([str(v) for v in row]), axis="columns" + ) + ds_snp_frq["variant_position_aa_change"] = "variants", grouper_var + + # Group by position and amino acid change. + group_by_aa_change = ds_snp_frq.groupby("variant_position_aa_change") + + # Apply aggregation. + ds_aa_frq = group_by_aa_change.map(_map_snp_to_aa_change_frq_ds) + + # Add back in cohort variables, unaffected by aggregation. + cohort_vars = [v for v in ds_snp_frq if v.startswith("cohort_")] + for v in cohort_vars: + ds_aa_frq[v] = ds_snp_frq[v] + + # Sort by genomic position. + ds_aa_frq = ds_aa_frq.sortby(["variant_position", "variant_aa_change"]) + + # Recompute frequency. + count = ds_aa_frq["event_count"].values + nobs = ds_aa_frq["event_nobs"].values + with np.errstate(divide="ignore", invalid="ignore"): + frequency = count / nobs # ignore division warnings + ds_aa_frq["event_frequency"] = ("variants", "cohorts"), frequency + + # Recompute max frequency over cohorts. + with warnings.catch_warnings(): + # Ignore "All-NaN slice encountered" warnings. + warnings.simplefilter("ignore", category=RuntimeWarning) + max_af = np.nanmax(ds_aa_frq["event_frequency"].values, axis=1) + ds_aa_frq["variant_max_af"] = "variants", max_af + + # Set up variant dataframe, useful intermediate. + variant_cols = [v for v in ds_aa_frq if v.startswith("variant_")] + df_variants = ds_aa_frq[variant_cols].to_dataframe() + df_variants.columns = [c.split("variant_")[1] for c in df_variants.columns] + + # Assign new variant label. + label = pandas_apply( + _make_snp_label_aa, + df_variants, + columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"], + ) + ds_aa_frq["variant_label"] = "variants", label + + # Apply variant query if given. + if variant_query is not None: + loc_variants = df_variants.eval(variant_query).values + + # Check for no SNPs remaining after applying variant query. + if np.count_nonzero(loc_variants) == 0: + raise ValueError( + f"No SNPs remaining after applying variant query {variant_query!r}." + ) + + ds_aa_frq = ds_aa_frq.isel(variants=loc_variants) + + # Compute new confidence intervals. + _add_frequency_ci(ds=ds_aa_frq, ci_method=ci_method) + + # Tidy up display by sorting variables. + ds_aa_frq = ds_aa_frq[sorted(ds_aa_frq)] + + gene_name = self._transcript_to_parent_name(transcript) + title = transcript + if gene_name: + title += f" ({gene_name})" + title += " SNP frequencies" + ds_aa_frq.attrs["title"] = title + + return ds_aa_frq + + @check_types + @doc( + summary=""" + Plot a heatmap from a pandas DataFrame of frequencies, e.g., output + from `snp_allele_frequencies()` or `gene_cnv_frequencies()`. + """, + parameters=dict( + df=""" + A DataFrame of frequencies, e.g., output from + `snp_allele_frequencies()` or `gene_cnv_frequencies()`. + """, + index=""" + One or more column headers that are present in the input dataframe. + This becomes the heatmap y-axis row labels. The column/s must + produce a unique index. + """, + max_len=""" + Displaying large styled dataframes may cause ipython notebooks to + crash. If the input dataframe is larger than this value, an error + will be raised. + """, + col_width=""" + Plot width per column in pixels (px). + """, + row_height=""" + Plot height per row in pixels (px). + """, + kwargs=""" + Passed through to `px.imshow()`. 
+ """, + ), + notes=""" + It's recommended to filter the input DataFrame to just rows of interest, + i.e., fewer rows than `max_len`. + """, + ) + def plot_frequencies_heatmap( + self, + df: pd.DataFrame, + index: Optional[Union[str, List[str]]] = "label", + max_len: Optional[int] = 100, + col_width: int = 40, + row_height: int = 20, + x_label: plotly_params.x_label = "Cohorts", + y_label: plotly_params.y_label = "Variants", + colorbar: plotly_params.colorbar = True, + width: plotly_params.width = None, + height: plotly_params.height = None, + text_auto: plotly_params.text_auto = ".0%", + aspect: plotly_params.aspect = "auto", + color_continuous_scale: plotly_params.color_continuous_scale = "Reds", + title: plotly_params.title = True, + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + **kwargs, + ) -> plotly_params.figure: + # Check len of input. + if max_len and len(df) > max_len: + raise ValueError( + dedent( + f""" + Input DataFrame is longer than max_len parameter value {max_len}, which means + that the plot is likely to be very large. If you really want to go ahead, + please rerun the function with max_len=None. + """ + ) + ) + + # Handle title. + if title is True: + title = df.attrs.get("title", None) + + # Indexing. + if index is None: + index = list(df.index.names) + df = df.reset_index().copy() + if isinstance(index, list): + index_col = ( + df[index] + .astype(str) + .apply( + lambda row: ", ".join([o for o in row if o is not None]), + axis="columns", + ) + ) + else: + assert isinstance(index, str) + index_col = df[index].astype(str) + + # Check that index is unique. + if not index_col.is_unique: + raise ValueError(f"{index} does not produce a unique index") + + # Drop and re-order columns. + frq_cols = [col for col in df.columns if col.startswith("frq_")] + + # Keep only freq cols. + heatmap_df = df[frq_cols].copy() + + # Set index. + heatmap_df.set_index(index_col, inplace=True) + + # Clean column names. + heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_") + + # Deal with width and height. + if width is None: + width = 400 + col_width * len(heatmap_df.columns) + if colorbar: + width += 40 + if height is None: + height = 200 + row_height * len(heatmap_df) + if title is not None: + height += 40 + + # Plotly heatmap styling. + fig = px.imshow( + img=heatmap_df, + zmin=0, + zmax=1, + width=width, + height=height, + text_auto=text_auto, + aspect=aspect, + color_continuous_scale=color_continuous_scale, + title=title, + **kwargs, + ) + + fig.update_xaxes(side="bottom", tickangle=30) + if x_label is not None: + fig.update_xaxes(title=x_label) + if y_label is not None: + fig.update_yaxes(title=y_label) + fig.update_layout( + coloraxis_colorbar=dict( + title="Frequency", + tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], + ticktext=["0%", "20%", "40%", "60%", "80%", "100%"], + ) + ) + if not colorbar: + fig.update(layout_coloraxis_showscale=False) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + @check_types + @doc( + summary="Create a time series plot of variant frequencies using plotly.", + parameters=dict( + ds=""" + A dataset of variant frequencies, such as returned by + `snp_allele_frequencies_advanced()`, + `aa_allele_frequencies_advanced()` or + `gene_cnv_frequencies_advanced()`. + """, + kwargs="Passed through to `px.line()`.", + ), + returns=""" + A plotly figure containing line graphs. 
The resulting figure will + have one panel per cohort, grouped into columns by taxon, and + grouped into rows by area. Markers and lines show frequencies of + variants. + """, + ) + def plot_frequencies_time_series( + self, + ds: xr.Dataset, + height: plotly_params.height = None, + width: plotly_params.width = None, + title: plotly_params.title = True, + legend_sizing: plotly_params.legend_sizing = "constant", + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + **kwargs, + ) -> plotly_params.figure: + # Handle title. + if title is True: + title = ds.attrs.get("title", None) + + # Extract cohorts into a dataframe. + cohort_vars = [v for v in ds if str(v).startswith("cohort_")] + df_cohorts = ds[cohort_vars].to_dataframe() + df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore + + # Extract variant labels. + variant_labels = ds["variant_label"].values + + # Build a long-form dataframe from the dataset. + dfs = [] + for cohort_index, cohort in enumerate(df_cohorts.itertuples()): + ds_cohort = ds.isel(cohorts=cohort_index) + df = pd.DataFrame( + { + "taxon": cohort.taxon, + "area": cohort.area, + "date": cohort.period_start, + "period": str( + cohort.period + ), # use string representation for hover label + "sample_size": cohort.size, + "variant": variant_labels, + "count": ds_cohort["event_count"].values, + "nobs": ds_cohort["event_nobs"].values, + "frequency": ds_cohort["event_frequency"].values, + "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values, + "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values, + } + ) + dfs.append(df) + df_events = pd.concat(dfs, axis=0).reset_index(drop=True) + + # Remove events with no observations. + df_events = df_events.query("nobs > 0").copy() + + # Calculate error bars. + frq = df_events["frequency"] + frq_ci_low = df_events["frequency_ci_low"] + frq_ci_upp = df_events["frequency_ci_upp"] + df_events["frequency_error"] = frq_ci_upp - frq + df_events["frequency_error_minus"] = frq - frq_ci_low + + # Make a plot. + fig = px.line( + df_events, + facet_col="taxon", + facet_row="area", + x="date", + y="frequency", + error_y="frequency_error", + error_y_minus="frequency_error_minus", + color="variant", + markers=True, + hover_name="variant", + hover_data={ + "frequency": ":.0%", + "period": True, + "area": True, + "taxon": True, + "sample_size": True, + "date": False, + "variant": False, + }, + height=height, + width=width, + title=title, + labels={ + "date": "Date", + "frequency": "Frequency", + "variant": "Variant", + "taxon": "Taxon", + "area": "Area", + "period": "Period", + "sample_size": "Sample size", + }, + **kwargs, + ) + + # Tidy plot. + fig.update_layout( + yaxis_range=[-0.05, 1.05], + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), + ) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + @check_types + @doc( + summary=""" + Plot markers on a map showing variant frequencies for cohorts grouped + by area (space), period (time) and taxon. + """, + parameters=dict( + m="The map on which to add the markers.", + variant="Index or label of variant to plot.", + taxon="Taxon to show markers for.", + period="Time period to show markers for.", + clear=""" + If True, clear all layers (except the base layer) from the map + before adding new markers. 
+ """, + ), + ) + def plot_frequencies_map_markers( + self, + m, + ds: frq_params.ds_frequencies_advanced, + variant: Union[int, str], + taxon: str, + period: pd.Period, + clear: bool = True, + ): + # Only import here because of some problems importing globally. + import ipyleaflet # type: ignore + import ipywidgets # type: ignore + + # Slice dataset to variant of interest. + if isinstance(variant, int): + ds_variant = ds.isel(variants=variant) + variant_label = ds["variant_label"].values[variant] + else: + assert isinstance(variant, str) + ds_variant = ds.set_index(variants="variant_label").sel(variants=variant) + variant_label = variant + + # Convert to a dataframe for convenience. + df_markers = ds_variant[ + [ + "cohort_taxon", + "cohort_area", + "cohort_period", + "cohort_lat_mean", + "cohort_lon_mean", + "cohort_size", + "event_frequency", + "event_frequency_ci_low", + "event_frequency_ci_upp", + ] + ].to_dataframe() + + # Select data matching taxon and period parameters. + df_markers = df_markers.loc[ + ( + (df_markers["cohort_taxon"] == taxon) + & (df_markers["cohort_period"] == period) + ) + ] + + # Clear existing layers in the map. + if clear: + for layer in m.layers[1:]: + m.remove_layer(layer) + + # Add markers. + for x in df_markers.itertuples(): + marker = ipyleaflet.CircleMarker() + marker.location = (x.cohort_lat_mean, x.cohort_lon_mean) + marker.radius = 20 + marker.color = "black" + marker.weight = 1 + marker.fill_color = "red" + marker.fill_opacity = x.event_frequency + popup_html = f""" + {variant_label}
+ Taxon: {x.cohort_taxon} <br/>
+ Area: {x.cohort_area} <br/>
+ Period: {x.cohort_period} <br/>
+ Sample size: {x.cohort_size} <br/>
+ Frequency: {x.event_frequency:.0%} + (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%}) + """ + marker.popup = ipyleaflet.Popup( + child=ipywidgets.HTML(popup_html), + ) + m.add(marker) + + @check_types + @doc( + summary=""" + Create an interactive map with markers showing variant frequencies or + cohorts grouped by area (space), period (time) and taxon. + """, + parameters=dict( + title=""" + If True, attempt to use metadata from input dataset as a plot + title. Otherwise, use supplied value as a title. + """, + epilogue="Additional text to display below the map.", + ), + returns=""" + An interactive map with widgets for selecting which variant, taxon + and time period to display. + """, + ) + def plot_frequencies_interactive_map( + self, + ds: frq_params.ds_frequencies_advanced, + center: map_params.center = map_params.center_default, + zoom: map_params.zoom = map_params.zoom_default, + title: Union[bool, str] = True, + epilogue: Union[bool, str] = True, + ): + import ipyleaflet + import ipywidgets + + # Handle title. + if title is True: + title = ds.attrs.get("title", None) + + # Create a map. + freq_map = ipyleaflet.Map(center=center, zoom=zoom) + + # Set up interactive controls. + variants = ds["variant_label"].values + taxa = np.unique(ds["cohort_taxon"].values) + periods = np.unique(ds["cohort_period"].values) + controls = ipywidgets.interactive( + self.plot_frequencies_map_markers, + m=ipywidgets.fixed(freq_map), + ds=ipywidgets.fixed(ds), + variant=ipywidgets.Dropdown(options=variants, description="Variant: "), + taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "), + period=ipywidgets.Dropdown(options=periods, description="Period: "), + clear=ipywidgets.fixed(True), + ) + + # Lay out widgets. + components = [] + if title is not None: + components.append(ipywidgets.HTML(value=f"

<h3>{title}</h3>

")) + components.append(controls) + components.append(freq_map) + if epilogue is True: + epilogue = """ + Variant frequencies are shown as coloured markers. Opacity of color + denotes frequency. Click on a marker for more information. + """ + if epilogue: + components.append(ipywidgets.HTML(value=f"{epilogue}")) + + out = ipywidgets.VBox(components) + + return out + + +def _make_snp_label(contig, position, ref_allele, alt_allele): + return f"{contig}:{position:,} {ref_allele}>{alt_allele}" + + +def _make_snp_label_effect(contig, position, ref_allele, alt_allele, aa_change): + label = f"{contig}:{position:,} {ref_allele}>{alt_allele}" + if isinstance(aa_change, str): + label += f" ({aa_change})" + return label + + +def _make_snp_label_aa(aa_change, contig, position, ref_allele, alt_allele): + label = f"{aa_change} ({contig}:{position:,} {ref_allele}>{alt_allele})" + return label + + +def _make_sample_period_month(row): + year = row.year + month = row.month + if year > 0 and month > 0: + return pd.Period(freq="M", year=year, month=month) + else: + return pd.NaT + + +def _make_sample_period_quarter(row): + year = row.year + month = row.month + if year > 0 and month > 0: + return pd.Period(freq="Q", year=year, month=month) + else: + return pd.NaT + + +def _make_sample_period_year(row): + year = row.year + if year > 0: + return pd.Period(freq="A", year=year) + else: + return pd.NaT + + +def _prep_samples_for_cohort_grouping(*, df_samples, area_by, period_by): + # Take a copy, as we will modify the dataframe. + df_samples = df_samples.copy() + + # Fix intermediate taxon values - we only want to build cohorts with clean + # taxon calls, so we set intermediate values to None. + loc_intermediate_taxon = ( + df_samples["taxon"].str.startswith("intermediate").fillna(False) + ) + df_samples.loc[loc_intermediate_taxon, "taxon"] = None + + # Add period column. + if period_by == "year": + make_period = _make_sample_period_year + elif period_by == "quarter": + make_period = _make_sample_period_quarter + elif period_by == "month": + make_period = _make_sample_period_month + else: # pragma: no cover + raise ValueError( + f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}." + ) + sample_period = df_samples.apply(make_period, axis="columns") + df_samples["period"] = sample_period + + # Add area column for consistent output. + df_samples["area"] = df_samples[area_by] + + return df_samples + + +def _build_cohorts_from_sample_grouping(*, group_samples_by_cohort, min_cohort_size): + # Build cohorts dataframe. + df_cohorts = group_samples_by_cohort.agg( + size=("sample_id", len), + lat_mean=("latitude", "mean"), + lat_max=("latitude", "mean"), + lat_min=("latitude", "mean"), + lon_mean=("longitude", "mean"), + lon_max=("longitude", "mean"), + lon_min=("longitude", "mean"), + ) + # Reset index so that the index fields are included as columns. + df_cohorts = df_cohorts.reset_index() + + # Add cohort helper variables. + cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time) + cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time) + df_cohorts["period_start"] = cohort_period_start + df_cohorts["period_end"] = cohort_period_end + # Create a label that is similar to the cohort metadata, + # although this won't be perfect. + df_cohorts["label"] = df_cohorts.apply( + lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns" + ) + + # Apply minimum cohort size. 
+ df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index(drop=True) + + return df_cohorts + + +def _cohort_alt_allele_counts_melt(*, gt, indices, max_allele): + ac_alt_melt, an = _cohort_alt_allele_counts_melt_kernel(gt, indices, max_allele) + an_melt = np.repeat(an, max_allele, axis=0) + return ac_alt_melt, an_melt + + +@numba.njit +def _cohort_alt_allele_counts_melt_kernel( + gt, sample_indices, max_allele +): # pragma: no cover + n_variants = gt.shape[0] + n_samples = sample_indices.shape[0] + ploidy = gt.shape[2] + + ac_alt_melt = np.zeros(n_variants * max_allele, dtype=np.int64) + an = np.zeros(n_variants, dtype=np.int64) + + for i in range(n_variants): + out_i_offset = (i * max_allele) - 1 + for j in range(n_samples): + sample_index = sample_indices[j] + for k in range(ploidy): + allele = gt[i, sample_index, k] + if allele > 0: + out_i = out_i_offset + allele + ac_alt_melt[out_i] += 1 + an[i] += 1 + elif allele == 0: + an[i] += 1 + + return ac_alt_melt, an + + +def _add_frequency_ci(*, ds, ci_method): + from statsmodels.stats.proportion import proportion_confint # type: ignore + + if ci_method is not None: + count = ds["event_count"].values + nobs = ds["event_nobs"].values + with np.errstate(divide="ignore", invalid="ignore"): + frq_ci_low, frq_ci_upp = proportion_confint( + count=count, nobs=nobs, method=ci_method + ) + ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low + ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp + + +def _map_snp_to_aa_change_frq_ds(ds): + # Keep only variables that make sense for amino acid substitutions. + keep_vars = [ + "variant_contig", + "variant_position", + "variant_transcript", + "variant_effect", + "variant_impact", + "variant_aa_pos", + "variant_aa_change", + "variant_ref_allele", + "variant_ref_aa", + "variant_alt_aa", + "event_nobs", + ] + + if ds.sizes["variants"] == 1: + # Keep everything as-is, no need for aggregation. + ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]] + + else: + # Take the first value from all variants variables. + ds_out = ds[keep_vars].isel(variants=[0]) + + # Sum event count over variants. + count = ds["event_count"].values.sum(axis=0, keepdims=True) + ds_out["event_count"] = ("variants", "cohorts"), count + + # Collapse alt allele. + alt_allele = "{" + ",".join(ds["variant_alt_allele"].values) + "}" + ds_out["variant_alt_allele"] = ( + "variants", + np.array([alt_allele], dtype=object), + ) + + return ds_out diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 783e8017e..591ed07f3 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -3,7 +3,7 @@ from abc import abstractmethod from bisect import bisect_left, bisect_right from collections import Counter -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, Sequence +from typing import Any, Dict, List, Mapping, Optional, Tuple, Sequence import allel # type: ignore import bokeh.layouts @@ -20,14 +20,18 @@ from numpydoc_decorator import doc # type: ignore from malariagen_data.anoph import tree_params +from malariagen_data.anoph.snp_frq import ( + AnophelesSnpFrequencyAnalysis, + _add_frequency_ci, + _build_cohorts_from_sample_grouping, + _prep_samples_for_cohort_grouping, +) -from . 
import veff from .anoph import ( aim_params, base_params, dash_params, diplotype_distance_params, - frq_params, fst_params, g123_params, gplt_params, @@ -36,7 +40,6 @@ hapnet_params, het_params, ihs_params, - map_params, plotly_params, xpehh_params, ) @@ -49,7 +52,7 @@ from .anoph.hap_data import AnophelesHapData, hap_params from .anoph.igv import AnophelesIgv from .anoph.pca import AnophelesPca -from .anoph.sample_metadata import AnophelesSampleMetadata +from .anoph.sample_metadata import AnophelesSampleMetadata, locate_cohorts from .anoph.snp_data import AnophelesSnpData from .mjn import median_joining_network, mjn_graph from .util import ( @@ -61,18 +64,15 @@ biallelic_diplotype_sqeuclidean, check_types, jackknife_ci, - locate_region, parse_multi_region, parse_single_region, pdist_abs_hamming, plotly_discrete_legend, region_str, simple_xarray_concat, + pandas_apply, ) -AA_CHANGE_QUERY = ( - "effect in ['NON_SYNONYMOUS_CODING', 'START_LOST', 'STOP_LOST', 'STOP_GAINED']" -) DEFAULT_MAX_COVERAGE_VARIANCE = 0.2 @@ -98,6 +98,7 @@ # work around pycharm failing to recognise that doc() is callable # noinspection PyCallingNonCallable class AnophelesDataResource( + AnophelesSnpFrequencyAnalysis, AnophelesPca, AnophelesIgv, AnophelesCnvData, @@ -140,6 +141,7 @@ def __init__( storage_options: Mapping, # used by fsspec via init_filesystem(url, **kwargs) taxon_colors: Optional[Mapping[str, str]], virtual_contigs: Optional[Mapping[str, Sequence[str]]], + gene_names: Optional[Mapping[str, str]], ): super().__init__( url=url, @@ -169,12 +171,9 @@ def __init__( tqdm_class=tqdm_class, taxon_colors=taxon_colors, virtual_contigs=virtual_contigs, + gene_names=gene_names, ) - # set up caches - # TODO review type annotations here, maybe can tighten - self._cache_annotator = None - @property @abstractmethod def _fst_gwss_results_cache_name(self): @@ -215,311 +214,6 @@ def _h1x_gwss_cache_name(self): def _ihs_gwss_cache_name(self): raise NotImplementedError("Must override _ihs_gwss_cache_name") - @abstractmethod - def _transcript_to_gene_name(self, transcript): - # children may have different manual overrides. - raise NotImplementedError("Must override _transcript_to_gene_name") - - @check_types - @doc( - summary=""" - Group samples by taxon, area (space) and period (time), then compute - SNP allele frequencies. - """, - returns=""" - The resulting dataset contains data has dimensions "cohorts" and - "variants". Variables prefixed with "cohort" are 1-dimensional - arrays with data about the cohorts, such as the area, period, taxon - and cohort size. Variables prefixed with "variant" are - 1-dimensional arrays with data about the variants, such as the - contig, position, reference and alternate alleles. Variables - prefixed with "event" are 2-dimensional arrays with the allele - counts and frequency calculations. 
- """, - ) - def snp_allele_frequencies_advanced( - self, - transcript: base_params.transcript, - area_by: frq_params.area_by, - period_by: frq_params.period_by, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - min_cohort_size: base_params.min_cohort_size = 10, - drop_invariant: frq_params.drop_invariant = True, - variant_query: Optional[frq_params.variant_query] = None, - site_mask: Optional[base_params.site_mask] = None, - nobs_mode: frq_params.nobs_mode = frq_params.nobs_mode_default, - ci_method: Optional[frq_params.ci_method] = frq_params.ci_method_default, - ) -> xr.Dataset: - debug = self._log.debug - - debug("check parameters") - self._check_param_min_cohort_size(min_cohort_size) - - debug("load sample metadata") - df_samples = self.sample_metadata( - sample_sets=sample_sets, sample_query=sample_query - ) - - debug("access SNP calls") - ds_snps = self.snp_calls( - region=transcript, - sample_sets=sample_sets, - sample_query=sample_query, - site_mask=site_mask, - ) - - debug("access genotypes") - gt = ds_snps["call_genotype"].data - - debug("prepare sample metadata for cohort grouping") - df_samples = self._prep_samples_for_cohort_grouping( - df_samples=df_samples, - area_by=area_by, - period_by=period_by, - ) - - debug("group samples to make cohorts") - group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"]) - - debug("build cohorts dataframe") - df_cohorts = self._build_cohorts_from_sample_grouping( - group_samples_by_cohort, min_cohort_size - ) - - debug("bring genotypes into memory") - with self._dask_progress(desc="Load SNP genotypes"): - gt = gt.compute() - - debug("set up variant variables") - contigs = ds_snps.attrs["contigs"] - variant_contig = np.repeat( - [contigs[i] for i in ds_snps["variant_contig"].values], 3 - ) - variant_position = np.repeat(ds_snps["variant_position"].values, 3) - alleles = ds_snps["variant_allele"].values - variant_ref_allele = np.repeat(alleles[:, 0], 3) - variant_alt_allele = alleles[:, 1:].flatten() - variant_pass = dict() - for site_mask in self.site_mask_ids: - variant_pass[site_mask] = np.repeat( - ds_snps[f"variant_filter_pass_{site_mask}"].values, 3 - ) - - debug("setup main event variables") - n_variants, n_cohorts = len(variant_position), len(df_cohorts) - count = np.zeros((n_variants, n_cohorts), dtype=int) - nobs = np.zeros((n_variants, n_cohorts), dtype=int) - - debug("build event count and nobs for each cohort") - cohorts_iterator = self._progress( - enumerate(df_cohorts.itertuples()), - total=len(df_cohorts), - desc="Compute SNP allele frequencies", - ) - for cohort_index, cohort in cohorts_iterator: - cohort_key = cohort.taxon, cohort.area, cohort.period - sample_indices = group_samples_by_cohort.indices[cohort_key] - - cohort_ac, cohort_an = self._cohort_alt_allele_counts_melt( - gt, sample_indices, max_allele=3 - ) - count[:, cohort_index] = cohort_ac - - if nobs_mode == "called": - nobs[:, cohort_index] = cohort_an - elif nobs_mode == "fixed": - nobs[:, cohort_index] = cohort.size * 2 - else: - raise ValueError(f"Bad nobs_mode: {nobs_mode!r}") - - debug("compute frequency") - with np.errstate(divide="ignore", invalid="ignore"): - # ignore division warnings - frequency = count / nobs - - debug("compute maximum frequency over cohorts") - with warnings.catch_warnings(): - # ignore "All-NaN slice encountered" warnings - warnings.simplefilter("ignore", category=RuntimeWarning) - max_af = np.nanmax(frequency, axis=1) - - debug("make dataframe of 
SNPs") - df_variants_cols = { - "contig": variant_contig, - "position": variant_position, - "ref_allele": variant_ref_allele.astype("U1"), - "alt_allele": variant_alt_allele.astype("U1"), - "max_af": max_af, - } - for site_mask in self.site_mask_ids: - df_variants_cols[f"pass_{site_mask}"] = variant_pass[site_mask] - df_variants = pd.DataFrame(df_variants_cols) - - debug("deal with SNP alleles not observed") - if drop_invariant: - loc_variant = max_af > 0 - df_variants = df_variants.loc[loc_variant].reset_index(drop=True) - count = np.compress(loc_variant, count, axis=0) - nobs = np.compress(loc_variant, nobs, axis=0) - frequency = np.compress(loc_variant, frequency, axis=0) - - debug("set up variant effect annotator") - ann = self._annotator() - - debug("add effects to the dataframe") - ann.get_effects( - transcript=transcript, variants=df_variants, progress=self._progress - ) - - debug("add variant labels") - df_variants["label"] = self._pandas_apply( - self._make_snp_label_effect, - df_variants, - columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], - ) - - debug("build the output dataset") - ds_out = xr.Dataset() - - debug("cohort variables") - for coh_col in df_cohorts.columns: - ds_out[f"cohort_{coh_col}"] = "cohorts", df_cohorts[coh_col] - - debug("variant variables") - for snp_col in df_variants.columns: - ds_out[f"variant_{snp_col}"] = "variants", df_variants[snp_col] - - debug("event variables") - ds_out["event_count"] = ("variants", "cohorts"), count - ds_out["event_nobs"] = ("variants", "cohorts"), nobs - ds_out["event_frequency"] = ("variants", "cohorts"), frequency - - debug("apply variant query") - if variant_query is not None: - loc_variants = df_variants.eval(variant_query).values - ds_out = ds_out.isel(variants=loc_variants) - - debug("add confidence intervals") - self._add_frequency_ci(ds_out, ci_method) - - debug("tidy up display by sorting variables") - sorted_vars: List[str] = sorted([str(k) for k in ds_out.keys()]) - ds_out = ds_out[sorted_vars] - - debug("add metadata") - gene_name = self._transcript_to_gene_name(transcript) - title = transcript - if gene_name: - title += f" ({gene_name})" - title += " SNP frequencies" - ds_out.attrs["title"] = title - - return ds_out - - # Start of @staticmethod - - @staticmethod - def _locate_cohorts(*, cohorts, df_samples): - # build cohort dictionary where key=cohort_id, value=loc_coh - coh_dict = {} - - if isinstance(cohorts, dict): - # user has supplied a custom dictionary mapping cohort identifiers - # to pandas queries - - for coh, query in cohorts.items(): - # locate samples - loc_coh = df_samples.eval(query).values - coh_dict[coh] = loc_coh - - if isinstance(cohorts, str): - # user has supplied one of the predefined cohort sets - - # fix the string to match columns - if not cohorts.startswith("cohort_"): - cohorts = "cohort_" + cohorts - - # check the given cohort set exists - if cohorts not in df_samples.columns: - raise ValueError(f"{cohorts!r} is not a known cohort set") - cohort_labels = df_samples[cohorts].unique() - - # remove the nans and sort - cohort_labels = sorted([c for c in cohort_labels if isinstance(c, str)]) - for coh in cohort_labels: - loc_coh = df_samples[cohorts] == coh - coh_dict[coh] = loc_coh.values - - return coh_dict - - @staticmethod - def _make_sample_period_month(row): - year = row.year - month = row.month - if year > 0 and month > 0: - return pd.Period(freq="M", year=year, month=month) - else: - return pd.NaT - - @staticmethod - def _make_sample_period_quarter(row): - year = 
row.year - month = row.month - if year > 0 and month > 0: - return pd.Period(freq="Q", year=year, month=month) - else: - return pd.NaT - - @staticmethod - def _make_sample_period_year(row): - year = row.year - if year > 0: - return pd.Period(freq="A", year=year) - else: - return pd.NaT - - @staticmethod - @numba.njit - def _cohort_alt_allele_counts_melt_kernel(gt, indices, max_allele): - n_variants = gt.shape[0] - n_indices = indices.shape[0] - ploidy = gt.shape[2] - - ac_alt_melt = np.zeros(n_variants * max_allele, dtype=np.int64) - an = np.zeros(n_variants, dtype=np.int64) - - for i in range(n_variants): - out_i_offset = (i * max_allele) - 1 - for j in range(n_indices): - ix = indices[j] - for k in range(ploidy): - allele = gt[i, ix, k] - if allele > 0: - out_i = out_i_offset + allele - ac_alt_melt[out_i] += 1 - an[i] += 1 - elif allele == 0: - an[i] += 1 - - return ac_alt_melt, an - - @staticmethod - def _make_snp_label(contig, position, ref_allele, alt_allele): - return f"{contig}:{position:,} {ref_allele}>{alt_allele}" - - @staticmethod - def _make_snp_label_effect(contig, position, ref_allele, alt_allele, aa_change): - label = f"{contig}:{position:,} {ref_allele}>{alt_allele}" - if isinstance(aa_change, str): - label += f" ({aa_change})" - return label - - @staticmethod - def _make_snp_label_aa(aa_change, contig, position, ref_allele, alt_allele): - label = f"{aa_change} ({contig}:{position:,} {ref_allele}>{alt_allele})" - return label - @staticmethod def _make_gene_cnv_label(gene_id, gene_name, cnv_type): label = gene_id @@ -528,110 +222,6 @@ def _make_gene_cnv_label(gene_id, gene_name, cnv_type): label += f" {cnv_type}" return label - @staticmethod - def _map_snp_to_aa_change_frq_ds(ds): - # keep only variables that make sense for amino acid substitutions - keep_vars = [ - "variant_contig", - "variant_position", - "variant_transcript", - "variant_effect", - "variant_impact", - "variant_aa_pos", - "variant_aa_change", - "variant_ref_allele", - "variant_ref_aa", - "variant_alt_aa", - "event_nobs", - ] - - if ds.sizes["variants"] == 1: - # keep everything as-is, no need for aggregation - ds_out = ds[keep_vars + ["variant_alt_allele", "event_count"]] - - else: - # take the first value from all variants variables - ds_out = ds[keep_vars].isel(variants=[0]) - - # sum event count over variants - count = ds["event_count"].values.sum(axis=0, keepdims=True) - ds_out["event_count"] = ("variants", "cohorts"), count - - # collapse alt allele - alt_allele = "{" + ",".join(ds["variant_alt_allele"].values) + "}" - ds_out["variant_alt_allele"] = ( - "variants", - np.array([alt_allele], dtype=object), - ) - - return ds_out - - @staticmethod - def _add_frequency_ci(ds, ci_method): - from statsmodels.stats.proportion import proportion_confint # type: ignore - - if ci_method is not None: - count = ds["event_count"].values - nobs = ds["event_nobs"].values - with np.errstate(divide="ignore", invalid="ignore"): - frq_ci_low, frq_ci_upp = proportion_confint( - count=count, nobs=nobs, method=ci_method - ) - ds["event_frequency_ci_low"] = ("variants", "cohorts"), frq_ci_low - ds["event_frequency_ci_upp"] = ("variants", "cohorts"), frq_ci_upp - - @staticmethod - def _build_cohorts_from_sample_grouping(group_samples_by_cohort, min_cohort_size): - # build cohorts dataframe - df_cohorts = group_samples_by_cohort.agg( - size=("sample_id", len), - lat_mean=("latitude", "mean"), - lat_max=("latitude", "mean"), - lat_min=("latitude", "mean"), - lon_mean=("longitude", "mean"), - lon_max=("longitude", "mean"), - 
lon_min=("longitude", "mean"), - ) - # reset index so that the index fields are included as columns - df_cohorts = df_cohorts.reset_index() - - # add cohort helper variables - cohort_period_start = df_cohorts["period"].apply(lambda v: v.start_time) - cohort_period_end = df_cohorts["period"].apply(lambda v: v.end_time) - df_cohorts["period_start"] = cohort_period_start - df_cohorts["period_end"] = cohort_period_end - # create a label that is similar to the cohort metadata, - # although this won't be perfect - df_cohorts["label"] = df_cohorts.apply( - lambda v: f"{v.area}_{v.taxon[:4]}_{v.period}", axis="columns" - ) - - # apply minimum cohort size - df_cohorts = df_cohorts.query(f"size >= {min_cohort_size}").reset_index( - drop=True - ) - - return df_cohorts - - @staticmethod - def _check_param_min_cohort_size(min_cohort_size): - if not isinstance(min_cohort_size, int): - raise TypeError( - f"Type of parameter min_cohort_size must be int; found {type(min_cohort_size)}." - ) - if min_cohort_size < 1: - raise ValueError( - f"Value of parameter min_cohort_size must be at least 1; found {min_cohort_size}." - ) - - @staticmethod - def _pandas_apply(f, df, columns): - """Optimised alternative to pandas apply.""" - df = df.reset_index(drop=True) - iterator = zip(*[df[c].values for c in columns]) - ret = pd.Series((f(*vals) for vals in iterator)) - return ret - @staticmethod def _roh_hmm_predict( *, @@ -701,136 +291,6 @@ def _roh_hmm_predict( ] ] - def _snp_df(self, *, transcript: str) -> Tuple[Region, pd.DataFrame]: - """Set up a dataframe with SNP site and filter columns.""" - debug = self._log.debug - - debug("get feature direct from genome_features") - gs = self.genome_features() - - with self._spinner(desc="Prepare SNP data"): - feature = gs[gs["ID"] == transcript].squeeze() - if feature.empty: - raise ValueError( - f"No genome feature ID found matching transcript {transcript}" - ) - contig = feature.contig - region = Region(contig, feature.start, feature.end) - - debug("grab pos, ref and alt for chrom arm from snp_sites") - pos = self.snp_sites(region=contig, field="POS") - ref = self.snp_sites(region=contig, field="REF") - alt = self.snp_sites(region=contig, field="ALT") - loc_feature = locate_region(region, pos) - pos = pos[loc_feature].compute() - ref = ref[loc_feature].compute() - alt = alt[loc_feature].compute() - - debug("access site filters") - filter_pass = dict() - masks = self.site_mask_ids - for m in masks: - x = self.site_filters(region=contig, mask=m) - x = x[loc_feature].compute() - filter_pass[m] = x - - debug("set up columns with contig, pos, ref, alt columns") - cols = { - "contig": contig, - "position": np.repeat(pos, 3), - "ref_allele": np.repeat(ref.astype("U1"), 3), - "alt_allele": alt.astype("U1").flatten(), - } - - debug("add mask columns") - for m in masks: - x = filter_pass[m] - cols[f"pass_{m}"] = np.repeat(x, 3) - - debug("construct dataframe") - df_snps = pd.DataFrame(cols) - - return region, df_snps - - def _annotator(self): - """Set up variant effect annotator.""" - if self._cache_annotator is None: - self._cache_annotator = veff.Annotator( - genome=self.open_genome(), genome_features=self.genome_features() - ) - return self._cache_annotator - - @check_types - @doc( - summary="Compute variant effects for a gene transcript.", - returns=""" - A dataframe of all possible SNP variants and their effects, one row - per variant. 
- """, - ) - def snp_effects( - self, - transcript: base_params.transcript, - site_mask: Optional[base_params.site_mask] = None, - ) -> pd.DataFrame: - debug = self._log.debug - - debug("setup initial dataframe of SNPs") - _, df_snps = self._snp_df(transcript=transcript) - - debug("setup variant effect annotator") - ann = self._annotator() - - debug("apply mask if requested") - if site_mask is not None: - loc_sites = df_snps[f"pass_{site_mask}"] - df_snps = df_snps.loc[loc_sites] - - debug("reset index after filtering") - df_snps.reset_index(inplace=True, drop=True) - - debug("add effects to the dataframe") - ann.get_effects(transcript=transcript, variants=df_snps) - - return df_snps - - def _cohort_alt_allele_counts_melt(self, gt, indices, max_allele): - ac_alt_melt, an = self._cohort_alt_allele_counts_melt_kernel( - gt, indices, max_allele - ) - an_melt = np.repeat(an, max_allele, axis=0) - return ac_alt_melt, an_melt - - def _prep_samples_for_cohort_grouping(self, *, df_samples, area_by, period_by): - # take a copy, as we will modify the dataframe - df_samples = df_samples.copy() - - # fix intermediate taxon values - we only want to build cohorts with clean - # taxon calls, so we set intermediate values to None - loc_intermediate_taxon = ( - df_samples["taxon"].str.startswith("intermediate").fillna(False) - ) - df_samples.loc[loc_intermediate_taxon, "taxon"] = None - - # add period column - if period_by == "year": - make_period = self._make_sample_period_year - elif period_by == "quarter": - make_period = self._make_sample_period_quarter - elif period_by == "month": - make_period = self._make_sample_period_month - else: - raise ValueError( - f"Value for period_by parameter must be one of 'year', 'quarter', 'month'; found {period_by!r}." - ) - sample_period = df_samples.apply(make_period, axis="columns") - df_samples["period"] = sample_period - - # add area column for consistent output - df_samples["area"] = df_samples[area_by] - - return df_samples - def _plot_heterozygosity_track( self, *, @@ -1356,352 +816,6 @@ def plot_roh( else: return fig_all - @check_types - @doc( - summary=""" - Compute SNP allele frequencies for a gene transcript. - """, - returns=""" - A dataframe of SNP allele frequencies, one row per variant allele. - """, - notes=""" - Cohorts with fewer samples than `min_cohort_size` will be excluded from - output data frame. 
- """, - ) - def snp_allele_frequencies( - self, - transcript: base_params.transcript, - cohorts: base_params.cohorts, - sample_query: Optional[base_params.sample_query] = None, - min_cohort_size: base_params.min_cohort_size = 10, - site_mask: Optional[base_params.site_mask] = None, - sample_sets: Optional[base_params.sample_sets] = None, - drop_invariant: frq_params.drop_invariant = True, - effects: frq_params.effects = True, - ) -> pd.DataFrame: - debug = self._log.debug - - debug("check parameters") - self._check_param_min_cohort_size(min_cohort_size) - - debug("access sample metadata") - df_samples = self.sample_metadata( - sample_sets=sample_sets, sample_query=sample_query - ) - - debug("setup initial dataframe of SNPs") - region, df_snps = self._snp_df(transcript=transcript) - - debug("get genotypes") - gt = self.snp_genotypes( - region=region, - sample_sets=sample_sets, - sample_query=sample_query, - field="GT", - ) - - debug("slice to feature location") - with self._dask_progress(desc="Load SNP genotypes"): - gt = gt.compute() - - debug("build coh dict") - coh_dict = self._locate_cohorts(cohorts=cohorts, df_samples=df_samples) - - debug("count alleles") - freq_cols = dict() - cohorts_iterator = self._progress( - coh_dict.items(), desc="Compute allele frequencies" - ) - for coh, loc_coh in cohorts_iterator: - n_samples = np.count_nonzero(loc_coh) - debug(f"{coh}, {n_samples} samples") - if n_samples >= min_cohort_size: - gt_coh = np.compress(loc_coh, gt, axis=1) - ac_coh = allel.GenotypeArray(gt_coh).count_alleles(max_allele=3) - af_coh = ac_coh.to_frequencies() - freq_cols["frq_" + coh] = af_coh[:, 1:].flatten() - - debug("build a dataframe with the frequency columns") - df_freqs = pd.DataFrame(freq_cols) - - debug("compute max_af") - df_max_af = pd.DataFrame({"max_af": df_freqs.max(axis=1)}) - - debug("build the final dataframe") - df_snps.reset_index(drop=True, inplace=True) - df_snps = pd.concat([df_snps, df_freqs, df_max_af], axis=1) - - debug("apply site mask if requested") - if site_mask is not None: - loc_sites = df_snps[f"pass_{site_mask}"] - df_snps = df_snps.loc[loc_sites] - - debug("drop invariants") - if drop_invariant: - loc_variant = df_snps["max_af"] > 0 - df_snps = df_snps.loc[loc_variant] - - debug("reset index after filtering") - df_snps.reset_index(inplace=True, drop=True) - - if effects: - debug("add effect annotations") - ann = self._annotator() - ann.get_effects( - transcript=transcript, variants=df_snps, progress=self._progress - ) - - debug("add label") - df_snps["label"] = self._pandas_apply( - self._make_snp_label_effect, - df_snps, - columns=["contig", "position", "ref_allele", "alt_allele", "aa_change"], - ) - - debug("set index") - df_snps.set_index( - ["contig", "position", "ref_allele", "alt_allele", "aa_change"], - inplace=True, - ) - - else: - debug("add label") - df_snps["label"] = self._pandas_apply( - self._make_snp_label, - df_snps, - columns=["contig", "position", "ref_allele", "alt_allele"], - ) - - debug("set index") - df_snps.set_index( - ["contig", "position", "ref_allele", "alt_allele"], - inplace=True, - ) - - debug("add dataframe metadata") - gene_name = self._transcript_to_gene_name(transcript) - title = transcript - if gene_name: - title += f" ({gene_name})" - title += " SNP frequencies" - df_snps.attrs["title"] = title - - return df_snps - - @check_types - @doc( - summary=""" - Compute amino acid substitution frequencies for a gene transcript. 
- """, - returns=""" - A dataframe of amino acid allele frequencies, one row per - substitution. - """, - notes=""" - Cohorts with fewer samples than `min_cohort_size` will be excluded from - output data frame. - """, - ) - def aa_allele_frequencies( - self, - transcript: base_params.transcript, - cohorts: base_params.cohorts, - sample_query: Optional[base_params.sample_query] = None, - min_cohort_size: Optional[base_params.min_cohort_size] = 10, - site_mask: Optional[base_params.site_mask] = None, - sample_sets: Optional[base_params.sample_sets] = None, - drop_invariant: frq_params.drop_invariant = True, - ) -> pd.DataFrame: - debug = self._log.debug - - df_snps = self.snp_allele_frequencies( - transcript=transcript, - cohorts=cohorts, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - site_mask=site_mask, - sample_sets=sample_sets, - drop_invariant=drop_invariant, - effects=True, - ) - df_snps.reset_index(inplace=True) - - # we just want aa change - df_ns_snps = df_snps.query(AA_CHANGE_QUERY).copy() - - # N.B., we need to worry about the possibility of the - # same aa change due to SNPs at different positions. We cannot - # sum frequencies of SNPs at different genomic positions. This - # is why we group by position and aa_change, not just aa_change. - - debug("group and sum to collapse multi variant allele changes") - freq_cols = [col for col in df_ns_snps if col.startswith("frq")] - agg: Dict[str, Union[Callable, str]] = {c: np.nansum for c in freq_cols} - keep_cols = ( - "contig", - "transcript", - "aa_pos", - "ref_allele", - "ref_aa", - "alt_aa", - "effect", - "impact", - ) - for c in keep_cols: - agg[c] = "first" - agg["alt_allele"] = lambda v: "{" + ",".join(v) + "}" if len(v) > 1 else v - df_aaf = df_ns_snps.groupby(["position", "aa_change"]).agg(agg).reset_index() - - debug("compute new max_af") - df_aaf["max_af"] = df_aaf[freq_cols].max(axis=1) - - debug("add label") - df_aaf["label"] = self._pandas_apply( - self._make_snp_label_aa, - df_aaf, - columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"], - ) - - debug("sort by genomic position") - df_aaf = df_aaf.sort_values(["position", "aa_change"]) - - debug("set index") - df_aaf.set_index(["aa_change", "contig", "position"], inplace=True) - - debug("add metadata") - gene_name = self._transcript_to_gene_name(transcript) - title = transcript - if gene_name: - title += f" ({gene_name})" - title += " SNP frequencies" - df_aaf.attrs["title"] = title - - return df_aaf - - @check_types - @doc( - summary=""" - Group samples by taxon, area (space) and period (time), then compute - amino acid change allele frequencies. - """, - returns=""" - The resulting dataset contains data has dimensions "cohorts" and - "variants". Variables prefixed with "cohort" are 1-dimensional - arrays with data about the cohorts, such as the area, period, taxon - and cohort size. Variables prefixed with "variant" are - 1-dimensional arrays with data about the variants, such as the - contig, position, reference and alternate alleles. Variables - prefixed with "event" are 2-dimensional arrays with the allele - counts and frequency calculations. 
- """, - ) - def aa_allele_frequencies_advanced( - self, - transcript: base_params.transcript, - area_by: frq_params.area_by, - period_by: frq_params.period_by, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - min_cohort_size: base_params.min_cohort_size = 10, - variant_query: Optional[frq_params.variant_query] = None, - site_mask: Optional[base_params.site_mask] = None, - nobs_mode: frq_params.nobs_mode = "called", - ci_method: Optional[frq_params.ci_method] = "wilson", - ) -> xr.Dataset: - debug = self._log.debug - - debug("begin by computing SNP allele frequencies") - ds_snp_frq = self.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - drop_invariant=True, # always drop invariant for aa frequencies - variant_query=AA_CHANGE_QUERY, # we'll also apply a variant query later - site_mask=site_mask, - nobs_mode=nobs_mode, - ci_method=None, # we will recompute confidence intervals later - ) - - # N.B., we need to worry about the possibility of the - # same aa change due to SNPs at different positions. We cannot - # sum frequencies of SNPs at different genomic positions. This - # is why we group by position and aa_change, not just aa_change. - - # add in a special grouping column to work around the fact that xarray currently - # doesn't support grouping by multiple variables in the same dimension - df_grouper = ds_snp_frq[ - ["variant_position", "variant_aa_change"] - ].to_dataframe() - grouper_var = df_grouper.apply( - lambda row: "_".join([str(v) for v in row]), axis="columns" - ) - ds_snp_frq["variant_position_aa_change"] = "variants", grouper_var - - debug("group by position and amino acid change") - group_by_aa_change = ds_snp_frq.groupby("variant_position_aa_change") - - debug("apply aggregation") - ds_aa_frq = group_by_aa_change.map(self._map_snp_to_aa_change_frq_ds) - - debug("add back in cohort variables, unaffected by aggregation") - cohort_vars = [v for v in ds_snp_frq if v.startswith("cohort_")] - for v in cohort_vars: - ds_aa_frq[v] = ds_snp_frq[v] - - debug("sort by genomic position") - ds_aa_frq = ds_aa_frq.sortby(["variant_position", "variant_aa_change"]) - - debug("recompute frequency") - count = ds_aa_frq["event_count"].values - nobs = ds_aa_frq["event_nobs"].values - with np.errstate(divide="ignore", invalid="ignore"): - frequency = count / nobs # ignore division warnings - ds_aa_frq["event_frequency"] = ("variants", "cohorts"), frequency - - debug("recompute max frequency over cohorts") - with warnings.catch_warnings(): - # ignore "All-NaN slice encountered" warnings - warnings.simplefilter("ignore", category=RuntimeWarning) - max_af = np.nanmax(ds_aa_frq["event_frequency"].values, axis=1) - ds_aa_frq["variant_max_af"] = "variants", max_af - - debug("set up variant dataframe, useful intermediate") - variant_cols = [v for v in ds_aa_frq if v.startswith("variant_")] - df_variants = ds_aa_frq[variant_cols].to_dataframe() - df_variants.columns = [c.split("variant_")[1] for c in df_variants.columns] - - debug("assign new variant label") - label = self._pandas_apply( - self._make_snp_label_aa, - df_variants, - columns=["aa_change", "contig", "position", "ref_allele", "alt_allele"], - ) - ds_aa_frq["variant_label"] = "variants", label - - debug("apply variant query if given") - if variant_query is not None: - loc_variants = df_variants.eval(variant_query).values - ds_aa_frq 
= ds_aa_frq.isel(variants=loc_variants) - - debug("compute new confidence intervals") - self._add_frequency_ci(ds_aa_frq, ci_method) - - debug("tidy up display by sorting variables") - ds_aa_frq = ds_aa_frq[sorted(ds_aa_frq)] - - gene_name = self._transcript_to_gene_name(transcript) - title = transcript - if gene_name: - title += f" ({gene_name})" - title += " SNP frequencies" - ds_aa_frq.attrs["title"] = title - - return ds_aa_frq - def gene_cnv( self, region: base_params.regions, @@ -1889,7 +1003,6 @@ def gene_cnv_frequencies( debug = self._log.debug debug("check and normalise parameters") - self._check_param_min_cohort_size(min_cohort_size) regions: List[Region] = parse_multi_region(self, region) del region @@ -1996,7 +1109,7 @@ def _gene_cnv_frequencies( is_called = cn >= 0 debug("set up cohort dict") - coh_dict = self._locate_cohorts(cohorts=cohorts, df_samples=df_samples) + coh_dict = locate_cohorts(cohorts=cohorts, data=df_samples) debug("compute cohort frequencies") freq_cols = dict() @@ -2046,7 +1159,7 @@ def _gene_cnv_frequencies( df.reset_index(drop=True, inplace=True) debug("add label") - df["label"] = self._pandas_apply( + df["label"] = pandas_apply( self._make_gene_cnv_label, df, columns=["gene_id", "gene_name", "cnv_type"] ) @@ -2123,8 +1236,6 @@ def gene_cnv_frequencies_advanced( """ - self._check_param_min_cohort_size(min_cohort_size) - regions: List[Region] = parse_multi_region(self, region) del region @@ -2187,7 +1298,7 @@ def _gene_cnv_frequencies_advanced( df_samples = df_samples.set_index("sample_id").loc[sample_id].reset_index() debug("prepare sample metadata for cohort grouping") - df_samples = self._prep_samples_for_cohort_grouping( + df_samples = _prep_samples_for_cohort_grouping( df_samples=df_samples, area_by=area_by, period_by=period_by, @@ -2197,8 +1308,9 @@ def _gene_cnv_frequencies_advanced( group_samples_by_cohort = df_samples.groupby(["taxon", "area", "period"]) debug("build cohorts dataframe") - df_cohorts = self._build_cohorts_from_sample_grouping( - group_samples_by_cohort, min_cohort_size + df_cohorts = _build_cohorts_from_sample_grouping( + group_samples_by_cohort=group_samples_by_cohort, + min_cohort_size=min_cohort_size, ) debug("figure out expected copy number") @@ -2267,7 +1379,7 @@ def _gene_cnv_frequencies_advanced( ) debug("add variant label") - df_variants["label"] = self._pandas_apply( + df_variants["label"] = pandas_apply( self._make_gene_cnv_label, df_variants, columns=["gene_id", "gene_name", "cnv_type"], @@ -2301,7 +1413,7 @@ def _gene_cnv_frequencies_advanced( ds_out = ds_out.isel(variants=loc_variants) debug("add confidence intervals") - self._add_frequency_ci(ds_out, ci_method) + _add_frequency_ci(ds=ds_out, ci_method=ci_method) debug("tidy up display by sorting variables") ds_out = ds_out[sorted(ds_out)] @@ -3021,441 +2133,6 @@ def _fst_gwss( return results - @check_types - @doc( - summary=""" - Plot a heatmap from a pandas DataFrame of frequencies, e.g., output - from `snp_allele_frequencies()` or `gene_cnv_frequencies()`. - """, - parameters=dict( - df=""" - A DataFrame of frequencies, e.g., output from - `snp_allele_frequencies()` or `gene_cnv_frequencies()`. - """, - index=""" - One or more column headers that are present in the input dataframe. - This becomes the heatmap y-axis row labels. The column/s must - produce a unique index. - """, - max_len=""" - Displaying large styled dataframes may cause ipython notebooks to - crash. If the input dataframe is larger than this value, an error - will be raised. 
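The call sites above switch from private methods to the module-level helpers locate_cohorts (now imported from .anoph.sample_metadata) and pandas_apply (now in util.py, added later in this diff). A usage sketch, assuming the relocated locate_cohorts keeps the behaviour of the removed _locate_cohorts (accepting either a predefined cohort column name or a mapping of labels to pandas queries); the sample metadata below is hypothetical:

import pandas as pd
from malariagen_data.anoph.sample_metadata import locate_cohorts
from malariagen_data.util import pandas_apply

df_samples = pd.DataFrame({
    "sample_id": ["A", "B", "C"],
    "taxon": ["coluzzii", "gambiae", "gambiae"],
    "cohort_admin1_year": ["ML-2_colu_2014", "ML-2_gamb_2014", "ML-2_gamb_2014"],
})

# Predefined cohort set (the "cohort_" prefix is added if missing)...
coh_dict = locate_cohorts(cohorts="admin1_year", data=df_samples)
# ...or a custom mapping of cohort label to pandas query.
coh_dict = locate_cohorts(cohorts={"gamb": "taxon == 'gambiae'"}, data=df_samples)
# Either way, values are boolean arrays locating each cohort's samples.

# pandas_apply() is a faster row-wise alternative to DataFrame.apply().
labels = pandas_apply(
    lambda s, t: f"{s} ({t})", df_samples, columns=["sample_id", "taxon"]
)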
- """, - col_width=""" - Plot width per column in pixels (px). - """, - row_height=""" - Plot height per row in pixels (px). - """, - kwargs=""" - Passed through to `px.imshow()`. - """, - ), - notes=""" - It's recommended to filter the input DataFrame to just rows of interest, - i.e., fewer rows than `max_len`. - """, - ) - def plot_frequencies_heatmap( - self, - df: pd.DataFrame, - index: Union[str, List[str]] = "label", - max_len: Optional[int] = 100, - col_width: int = 40, - row_height: int = 20, - x_label: plotly_params.x_label = "Cohorts", - y_label: plotly_params.y_label = "Variants", - colorbar: plotly_params.colorbar = True, - width: plotly_params.width = None, - height: plotly_params.height = None, - text_auto: plotly_params.text_auto = ".0%", - aspect: plotly_params.aspect = "auto", - color_continuous_scale: plotly_params.color_continuous_scale = "Reds", - title: plotly_params.title = True, - show: plotly_params.show = True, - renderer: plotly_params.renderer = None, - **kwargs, - ) -> plotly_params.figure: - debug = self._log.debug - - debug("check len of input") - if max_len and len(df) > max_len: - raise ValueError(f"Input DataFrame is longer than {max_len}") - - debug("handle title") - if title is True: - title = df.attrs.get("title", None) - - debug("indexing") - if index is None: - index = list(df.index.names) - df = df.reset_index().copy() - if isinstance(index, list): - index_col = ( - df[index] - .astype(str) - .apply( - lambda row: ", ".join([o for o in row if o is not None]), - axis="columns", - ) - ) - elif isinstance(index, str): - index_col = df[index].astype(str) - else: - raise TypeError("wrong type for index parameter, expected list or str") - - debug("check that index is unique") - if not index_col.is_unique: - raise ValueError(f"{index} does not produce a unique index") - - debug("drop and re-order columns") - frq_cols = [col for col in df.columns if col.startswith("frq_")] - - debug("keep only freq cols") - heatmap_df = df[frq_cols].copy() - - debug("set index") - heatmap_df.set_index(index_col, inplace=True) - - debug("clean column names") - heatmap_df.columns = heatmap_df.columns.str.lstrip("frq_") - - debug("deal with width and height") - if width is None: - width = 400 + col_width * len(heatmap_df.columns) - if colorbar: - width += 40 - if height is None: - height = 200 + row_height * len(heatmap_df) - if title is not None: - height += 40 - - debug("plotly heatmap styling") - fig = px.imshow( - img=heatmap_df, - zmin=0, - zmax=1, - width=width, - height=height, - text_auto=text_auto, - aspect=aspect, - color_continuous_scale=color_continuous_scale, - title=title, - **kwargs, - ) - - fig.update_xaxes(side="bottom", tickangle=30) - if x_label is not None: - fig.update_xaxes(title=x_label) - if y_label is not None: - fig.update_yaxes(title=y_label) - fig.update_layout( - coloraxis_colorbar=dict( - title="Frequency", - tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], - ticktext=["0%", "20%", "40%", "60%", "80%", "100%"], - ) - ) - if not colorbar: - fig.update(layout_coloraxis_showscale=False) - - if show: # pragma: no cover - fig.show(renderer=renderer) - return None - else: - return fig - - @check_types - @doc( - summary="Create a time series plot of variant frequencies using plotly.", - parameters=dict( - ds=""" - A dataset of variant frequencies, such as returned by - `snp_allele_frequencies_advanced()`, - `aa_allele_frequencies_advanced()` or - `gene_cnv_frequencies_advanced()`. 
- """, - kwargs="Passed through to `px.line()`.", - ), - returns=""" - A plotly figure containing line graphs. The resulting figure will - have one panel per cohort, grouped into columns by taxon, and - grouped into rows by area. Markers and lines show frequencies of - variants. - """, - ) - def plot_frequencies_time_series( - self, - ds: xr.Dataset, - height: plotly_params.height = None, - width: plotly_params.width = None, - title: plotly_params.title = True, - legend_sizing: plotly_params.legend_sizing = "constant", - show: plotly_params.show = True, - renderer: plotly_params.renderer = None, - **kwargs, - ) -> plotly_params.figure: - debug = self._log.debug - - debug("handle title") - if title is True: - title = ds.attrs.get("title", None) - - debug("extract cohorts into a dataframe") - cohort_vars = [v for v in ds if str(v).startswith("cohort_")] - df_cohorts = ds[cohort_vars].to_dataframe() - df_cohorts.columns = [c.split("cohort_")[1] for c in df_cohorts.columns] # type: ignore - - debug("extract variant labels") - variant_labels = ds["variant_label"].values - - debug("build a long-form dataframe from the dataset") - dfs = [] - for cohort_index, cohort in enumerate(df_cohorts.itertuples()): - ds_cohort = ds.isel(cohorts=cohort_index) - df = pd.DataFrame( - { - "taxon": cohort.taxon, - "area": cohort.area, - "date": cohort.period_start, - "period": str( - cohort.period - ), # use string representation for hover label - "sample_size": cohort.size, - "variant": variant_labels, - "count": ds_cohort["event_count"].values, - "nobs": ds_cohort["event_nobs"].values, - "frequency": ds_cohort["event_frequency"].values, - "frequency_ci_low": ds_cohort["event_frequency_ci_low"].values, - "frequency_ci_upp": ds_cohort["event_frequency_ci_upp"].values, - } - ) - dfs.append(df) - df_events = pd.concat(dfs, axis=0).reset_index(drop=True) - - debug("remove events with no observations") - df_events = df_events.query("nobs > 0") - - debug("calculate error bars") - frq = df_events["frequency"] - frq_ci_low = df_events["frequency_ci_low"] - frq_ci_upp = df_events["frequency_ci_upp"] - df_events["frequency_error"] = frq_ci_upp - frq - df_events["frequency_error_minus"] = frq - frq_ci_low - - debug("make a plot") - fig = px.line( - df_events, - facet_col="taxon", - facet_row="area", - x="date", - y="frequency", - error_y="frequency_error", - error_y_minus="frequency_error_minus", - color="variant", - markers=True, - hover_name="variant", - hover_data={ - "frequency": ":.0%", - "period": True, - "area": True, - "taxon": True, - "sample_size": True, - "date": False, - "variant": False, - }, - height=height, - width=width, - title=title, - labels={ - "date": "Date", - "frequency": "Frequency", - "variant": "Variant", - "taxon": "Taxon", - "area": "Area", - "period": "Period", - "sample_size": "Sample size", - }, - **kwargs, - ) - - debug("tidy plot") - fig.update_layout( - yaxis_range=[-0.05, 1.05], - legend=dict(itemsizing=legend_sizing, tracegroupgap=0), - ) - - if show: # pragma: no cover - fig.show(renderer=renderer) - return None - else: - return fig - - @check_types - @doc( - summary=""" - Plot markers on a map showing variant frequencies for cohorts grouped - by area (space), period (time) and taxon. - """, - parameters=dict( - m="The map on which to add the markers.", - variant="Index or label of variant to plot.", - taxon="Taxon to show markers for.", - period="Time period to show markers for.", - clear=""" - If True, clear all layers (except the base layer) from the map - before adding new markers. 
- """, - ), - ) - def plot_frequencies_map_markers( - self, - m, - ds: frq_params.ds_frequencies_advanced, - variant: Union[int, str], - taxon: str, - period: pd.Period, - clear: bool = True, - ): - debug = self._log.debug - # only import here because of some problems importing globally - import ipyleaflet # type: ignore - import ipywidgets # type: ignore - - debug("slice dataset to variant of interest") - if isinstance(variant, int): - ds_variant = ds.isel(variants=variant) - variant_label = ds["variant_label"].values[variant] - elif isinstance(variant, str): - ds_variant = ds.set_index(variants="variant_label").sel(variants=variant) - variant_label = variant - else: - raise TypeError( - f"Bad type for variant parameter; expected int or str, found {type(variant)}." - ) - - debug("convert to a dataframe for convenience") - df_markers = ds_variant[ - [ - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_lat_mean", - "cohort_lon_mean", - "cohort_size", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ] - ].to_dataframe() - - debug("select data matching taxon and period parameters") - df_markers = df_markers.loc[ - ( - (df_markers["cohort_taxon"] == taxon) - & (df_markers["cohort_period"] == period) - ) - ] - - debug("clear existing layers in the map") - if clear: - for layer in m.layers[1:]: - m.remove_layer(layer) - - debug("add markers") - for x in df_markers.itertuples(): - marker = ipyleaflet.CircleMarker() - marker.location = (x.cohort_lat_mean, x.cohort_lon_mean) - marker.radius = 20 - marker.color = "black" - marker.weight = 1 - marker.fill_color = "red" - marker.fill_opacity = x.event_frequency - popup_html = f""" - {variant_label}
- Taxon: {x.cohort_taxon} <br>
- Area: {x.cohort_area} <br>
- Period: {x.cohort_period} <br>
- Sample size: {x.cohort_size} <br>
- Frequency: {x.event_frequency:.0%} - (95% CI: {x.event_frequency_ci_low:.0%} - {x.event_frequency_ci_upp:.0%}) - """ - marker.popup = ipyleaflet.Popup( - child=ipywidgets.HTML(popup_html), - ) - m.add_layer(marker) - - @check_types - @doc( - summary=""" - Create an interactive map with markers showing variant frequencies or - cohorts grouped by area (space), period (time) and taxon. - """, - parameters=dict( - title=""" - If True, attempt to use metadata from input dataset as a plot - title. Otherwise, use supplied value as a title. - """, - epilogue="Additional text to display below the map.", - ), - returns=""" - An interactive map with widgets for selecting which variant, taxon - and time period to display. - """, - ) - def plot_frequencies_interactive_map( - self, - ds: frq_params.ds_frequencies_advanced, - center: map_params.center = map_params.center_default, - zoom: map_params.zoom = map_params.zoom_default, - title: Union[bool, str] = True, - epilogue: Union[bool, str] = True, - ): - debug = self._log.debug - - import ipyleaflet - import ipywidgets - - debug("handle title") - if title is True: - title = ds.attrs.get("title", None) - - debug("create a map") - freq_map = ipyleaflet.Map(center=center, zoom=zoom) - - debug("set up interactive controls") - variants = ds["variant_label"].values - taxa = np.unique(ds["cohort_taxon"].values) - periods = np.unique(ds["cohort_period"].values) - controls = ipywidgets.interactive( - self.plot_frequencies_map_markers, - m=ipywidgets.fixed(freq_map), - ds=ipywidgets.fixed(ds), - variant=ipywidgets.Dropdown(options=variants, description="Variant: "), - taxon=ipywidgets.Dropdown(options=taxa, description="Taxon: "), - period=ipywidgets.Dropdown(options=periods, description="Period: "), - clear=ipywidgets.fixed(True), - ) - - debug("lay out widgets") - components = [] - if title is not None: - components.append(ipywidgets.HTML(value=f"

<h3>{title}</h3>

")) - components.append(controls) - components.append(freq_map) - if epilogue is True: - epilogue = """ - Variant frequencies are shown as coloured markers. Opacity of color - denotes frequency. Click on a marker for more information. - """ - if epilogue: - components.append(ipywidgets.HTML(value=f"{epilogue}")) - - out = ipywidgets.VBox(components) - - return out - @check_types @doc( summary="Plot diversity summary statistics for multiple cohorts.", diff --git a/malariagen_data/util.py b/malariagen_data/util.py index d756f8430..8e3c52b6b 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -10,26 +10,27 @@ from textwrap import dedent, fill from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union from urllib.parse import unquote_plus +from numpy.testing import assert_allclose, assert_array_equal try: - from google import colab + from google import colab # type: ignore except ImportError: colab = None -import allel +import allel # type: ignore import dask.array as da -import ipinfo -import numba +import ipinfo # type: ignore +import numba # type: ignore import numpy as np import pandas import pandas as pd -import plotly.express as px +import plotly.express as px # type: ignore import typeguard import xarray as xr -import zarr -from fsspec.core import url_to_fs -from fsspec.mapping import FSMap -from numpydoc_decorator.impl import humanize_type +import zarr # type: ignore +from fsspec.core import url_to_fs # type: ignore +from fsspec.mapping import FSMap # type: ignore +from numpydoc_decorator.impl import humanize_type # type: ignore from typing_extensions import TypeAlias, get_type_hints DIM_VARIANT = "variants" @@ -116,8 +117,7 @@ def unpack_gff3_attributes(df: pd.DataFrame, attributes: Tuple[str, ...]): try: # zarr >= 2.11.0 - # noinspection PyUnresolvedReferences - from zarr.storage import KVStore + from zarr.storage import KVStore # type: ignore class SafeStore(KVStore): def __getitem__(self, key): @@ -788,7 +788,7 @@ def jackknife_ci(stat_data, jack_stat, confidence_level): https://github.com/astropy/astropy/blob/8aba9632597e6bb489488109222bf2feff5835a6/astropy/stats/jackknife.py#L55 """ - from scipy.special import erfinv + from scipy.special import erfinv # type: ignore n = len(jack_stat) @@ -931,7 +931,7 @@ def check_types(f): """ @wraps(f) - def wrapper(*args, **kwargs): + def check_types_wrapper(*args, **kwargs): type_hints = get_type_hints(f) call_args = getcallargs(f, *args, **kwargs) for k, t in type_hints.items(): @@ -955,7 +955,7 @@ def wrapper(*args, **kwargs): raise error from None return f(*args, **kwargs) - return wrapper + return check_types_wrapper @numba.njit @@ -1151,3 +1151,27 @@ def apply_allele_mapping(x, mapping, max_allele): out[i, new_allele_index] = x[i, allele_index] return out + + +def pandas_apply(f, df, columns): + """Optimised alternative to pandas apply.""" + df = df.reset_index(drop=True) + iterator = zip(*[df[c].values for c in columns]) + ret = pd.Series((f(*vals) for vals in iterator)) + return ret + + +def compare_series_like(actual, expect): + """Compare pandas series-like objects for equality or floating point + similarity, handling missing values appropriately.""" + + # Handle object arrays, these don't get nans compared properly. 
+ t = actual.dtype + if t == object: + expect = expect.fillna("NA") + actual = actual.fillna("NA") + + if t.kind == "f": + assert_allclose(actual.values, expect.values) + else: + assert_array_equal(actual.values, expect.values) diff --git a/malariagen_data/veff.py b/malariagen_data/veff.py index 2dcb655cc..918553582 100644 --- a/malariagen_data/veff.py +++ b/malariagen_data/veff.py @@ -1,7 +1,7 @@ import collections import operator -from Bio.Seq import Seq +from Bio.Seq import Seq # type: ignore VariantEffect = collections.namedtuple( "VariantEffect", @@ -212,8 +212,9 @@ def _get_within_transcript_effect(ann, base_effect, cdss, utr5, utr3, introns): effect = base_effect._replace(effect="THREE_PRIME_UTR", impact="LOW") return effect - # if none of the above - effect = base_effect._replace(effect="TODO", impact="UNKNOWN") + # If none of the above, all we can say is that the variant hits + # a transcript. + effect = base_effect._replace(effect="TRANSCRIPT", impact="MODIFIER") return effect @@ -364,7 +365,9 @@ def _get_within_cds_effect(ann, base_effect, cds, cdss): else: # TODO in-frame complex variation (MNP + INDEL) - effect = base_effect._replace(effect="TODO", impact="UNKNOWN") + effect = base_effect._replace( + effect="TODO in-frame complex variation (MNP + INDEL)", impact="UNKNOWN" + ) return effect @@ -536,7 +539,7 @@ def _get_within_intron_effect(base_effect, intron): effect = base_effect._replace(effect="INTRONIC", impact="MODIFIER") else: - # TODO INDELs and MNPs - effect = base_effect._replace(effect="TODO") + # TODO intronic INDELs and MNPs + effect = base_effect._replace(effect="TODO intronic indels and MNPs") return effect diff --git a/notebooks/plot_frequencies_space_time.ipynb b/notebooks/plot_frequencies_space_time.ipynb index 15512cb5d..a77246fe6 100644 --- a/notebooks/plot_frequencies_space_time.ipynb +++ b/notebooks/plot_frequencies_space_time.ipynb @@ -4,14 +4,7 @@ "cell_type": "code", "execution_count": null, "id": "f820bc66-2fb2-4ca2-9b54-824e50d61a0a", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:46.703988Z", - "iopub.status.busy": "2022-12-23T17:48:46.703072Z", - "iopub.status.idle": "2022-12-23T17:48:48.388767Z", - "shell.execute_reply": "2022-12-23T17:48:48.388206Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "import malariagen_data\n", @@ -52,14 +45,7 @@ "cell_type": "code", "execution_count": null, "id": "1c612c69-27ee-4f50-b467-786bd998de58", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:48.391134Z", - "iopub.status.busy": "2022-12-23T17:48:48.390799Z", - "iopub.status.idle": "2022-12-23T17:48:49.928704Z", - "shell.execute_reply": "2022-12-23T17:48:49.928102Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.gene_cnv_frequencies_advanced(\n", @@ -77,14 +63,7 @@ "cell_type": "code", "execution_count": null, "id": "7b38fb26-9562-4f52-9381-ced0687b5a40", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:49.931065Z", - "iopub.status.busy": "2022-12-23T17:48:49.930877Z", - "iopub.status.idle": "2022-12-23T17:48:50.235303Z", - "shell.execute_reply": "2022-12-23T17:48:50.234797Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_time_series(ds, height=500, width=1000)" @@ -94,14 +73,7 @@ "cell_type": "code", "execution_count": null, "id": "fb104e66-7ba0-42ec-9ed0-625a43b69d5b", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:50.278700Z", - "iopub.status.busy": "2022-12-23T17:48:50.278350Z", - 
"iopub.status.idle": "2022-12-23T17:48:52.814714Z", - "shell.execute_reply": "2022-12-23T17:48:52.814176Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.gene_cnv_frequencies_advanced(\n", @@ -118,14 +90,7 @@ "cell_type": "code", "execution_count": null, "id": "ccf5279d-fc5a-4efe-972e-60c7b9515558", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:52.817166Z", - "iopub.status.busy": "2022-12-23T17:48:52.816909Z", - "iopub.status.idle": "2022-12-23T17:48:52.905183Z", - "shell.execute_reply": "2022-12-23T17:48:52.904720Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_interactive_map(ds)" @@ -144,14 +109,7 @@ "cell_type": "code", "execution_count": null, "id": "bf75e5b1-f2ca-41f4-a0fd-08a03c979fa4", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:52.908656Z", - "iopub.status.busy": "2022-12-23T17:48:52.908467Z", - "iopub.status.idle": "2022-12-23T17:48:54.913588Z", - "shell.execute_reply": "2022-12-23T17:48:54.912997Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.aa_allele_frequencies_advanced(\n", @@ -170,14 +128,7 @@ "cell_type": "code", "execution_count": null, "id": "2c6ffd89-7edc-4f90-bd23-b03e15ce9714", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:54.915626Z", - "iopub.status.busy": "2022-12-23T17:48:54.915442Z", - "iopub.status.idle": "2022-12-23T17:48:55.001687Z", - "shell.execute_reply": "2022-12-23T17:48:55.001166Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_time_series(ds, height=400, width=600)" @@ -187,14 +138,7 @@ "cell_type": "code", "execution_count": null, "id": "c058a196-2f78-49a0-8a18-cb2420c3278a", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:55.003660Z", - "iopub.status.busy": "2022-12-23T17:48:55.003479Z", - "iopub.status.idle": "2022-12-23T17:48:59.787903Z", - "shell.execute_reply": "2022-12-23T17:48:59.787238Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.aa_allele_frequencies_advanced(\n", @@ -212,14 +156,7 @@ "cell_type": "code", "execution_count": null, "id": "e733c877-a1f1-4dc6-b8a3-a9652430ff39", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:59.790255Z", - "iopub.status.busy": "2022-12-23T17:48:59.790075Z", - "iopub.status.idle": "2022-12-23T17:48:59.818333Z", - "shell.execute_reply": "2022-12-23T17:48:59.817926Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_interactive_map(ds)" @@ -286,14 +223,7 @@ "cell_type": "code", "execution_count": null, "id": "c3738721-8d16-40dd-8a4d-4b588af36b15", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:48:59.821747Z", - "iopub.status.busy": "2022-12-23T17:48:59.821571Z", - "iopub.status.idle": "2022-12-23T17:49:01.166705Z", - "shell.execute_reply": "2022-12-23T17:49:01.166143Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.snp_allele_frequencies_advanced(\n", @@ -316,14 +246,7 @@ "cell_type": "code", "execution_count": null, "id": "5c743429", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:49:01.169029Z", - "iopub.status.busy": "2022-12-23T17:49:01.168811Z", - "iopub.status.idle": "2022-12-23T17:49:01.648657Z", - "shell.execute_reply": "2022-12-23T17:49:01.648133Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_time_series(ds, height=900, width=900)" @@ -333,14 +256,7 @@ "cell_type": "code", "execution_count": null, "id": "9f6a7920", - 
"metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:49:01.653225Z", - "iopub.status.busy": "2022-12-23T17:49:01.653021Z", - "iopub.status.idle": "2022-12-23T17:49:02.676879Z", - "shell.execute_reply": "2022-12-23T17:49:02.676318Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.snp_allele_frequencies_advanced(\n", @@ -359,14 +275,7 @@ "cell_type": "code", "execution_count": null, "id": "03cfc0e7", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:49:02.678825Z", - "iopub.status.busy": "2022-12-23T17:49:02.678648Z", - "iopub.status.idle": "2022-12-23T17:49:02.835617Z", - "shell.execute_reply": "2022-12-23T17:49:02.835088Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_time_series(ds, height=400, width=800)" @@ -376,14 +285,7 @@ "cell_type": "code", "execution_count": null, "id": "28922f73-f38d-424a-a9c5-1e839c86567b", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:49:02.837997Z", - "iopub.status.busy": "2022-12-23T17:49:02.837815Z", - "iopub.status.idle": "2022-12-23T17:49:06.334806Z", - "shell.execute_reply": "2022-12-23T17:49:06.334285Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ds = ag3.snp_allele_frequencies_advanced(\n", @@ -401,14 +303,7 @@ "cell_type": "code", "execution_count": null, "id": "2a8ea4a0-6bf7-41a3-b734-06535e7dbbc8", - "metadata": { - "execution": { - "iopub.execute_input": "2022-12-23T17:49:06.337135Z", - "iopub.status.busy": "2022-12-23T17:49:06.336916Z", - "iopub.status.idle": "2022-12-23T17:49:06.365703Z", - "shell.execute_reply": "2022-12-23T17:49:06.365218Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "ag3.plot_frequencies_interactive_map(ds)" diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index e08e374b4..4869cd4e2 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -73,8 +73,8 @@ def __init__( n_exons_low=1, n_exons_high=5, intron_size_low=10, - intron_size_high=1_000, - exon_size_low=10, + intron_size_high=100, + exon_size_low=100, exon_size_high=1_000, source="random", max_genes=1_000, @@ -256,15 +256,21 @@ def simulate_exons( transcript_size = transcript_end - transcript_start exons = [] exon_end = transcript_start - for exon_ix in range(randint(self.n_exons_low, self.n_exons_high)): + n_exons = randint(self.n_exons_low, self.n_exons_high) + for exon_ix in range(n_exons): exon_id = f"exon-{contig}-{gene_ix}-{transcript_ix}-{exon_ix}" - intron_size = randint( - self.intron_size_low, min(transcript_size, self.intron_size_high) - ) - exon_start = exon_end + intron_size - if exon_start >= transcript_end: - # Stop making exons, no more space left in the transcript. - break + if exon_ix > 0: + # Insert an intron between this exon and the previous one. + intron_size = randint( + self.intron_size_low, min(transcript_size, self.intron_size_high) + ) + exon_start = exon_end + intron_size + if exon_start >= transcript_end: + # Stop making exons, no more space left in the transcript. + break + else: + # First exon, assume exon starts where the transcript starts. + exon_start = transcript_start exon_size = randint(self.exon_size_low, self.exon_size_high) exon_end = min(exon_start + exon_size, transcript_end) assert exon_end > exon_start @@ -282,20 +288,20 @@ def simulate_exons( yield exon exons.append(exon) - # Note that this is not perfect, because sometimes we end up - # without any CDSs. Also in reality, an exon can contain - # part of a UTR and part of a CDS, but that is harder to - # simulate. 
So keep things simple for now. + # Note that this is not perfect, because in reality an exon can contain + # part of a UTR and part of a CDS, but that is harder to simulate. So + # keep things simple for now. if strand == "-": # Take exons in reverse order. exons == exons[::-1] for exon_ix, exon in enumerate(exons): first_exon = exon_ix == 0 last_exon = exon_ix == len(exons) - 1 - if first_exon: + # Ensure at least one CDS. + if first_exon and len(exons) > 1: feature_type = self.utr5_type phase = "." - elif last_exon: + elif last_exon and len(exons) > 2: feature_type = self.utr3_type phase = "." else: @@ -1005,11 +1011,14 @@ def contigs(self) -> Tuple[str, ...]: def random_contig(self): return choice(self.contigs) - def random_region_str(self): + def random_region_str(self, region_size=None): contig = self.random_contig() contig_size = self.contig_sizes[contig] region_start = randint(1, contig_size) - region_end = randint(region_start, contig_size) + if region_size: + region_end = region_start + region_size + else: + region_end = randint(region_start, contig_size) region = f"{contig}:{region_start:,}-{region_end:,}" return region diff --git a/tests/anoph/test_pca.py b/tests/anoph/test_pca.py index e340cde6a..2261f5b6d 100644 --- a/tests/anoph/test_pca.py +++ b/tests/anoph/test_pca.py @@ -2,7 +2,7 @@ import numpy as np import pandas as pd -import plotly.graph_objects as go +import plotly.graph_objects as go # type: ignore import pytest from pytest_cases import parametrize_with_cases diff --git a/tests/anoph/test_sample_metadata.py b/tests/anoph/test_sample_metadata.py index c3babbae8..1fd2e57fa 100644 --- a/tests/anoph/test_sample_metadata.py +++ b/tests/anoph/test_sample_metadata.py @@ -683,14 +683,21 @@ def test_count_samples(fixture, api): ) -@pytest.mark.parametrize( - "basemap", ["satellite", None, ipyleaflet.basemaps.OpenTopoMap] -) @parametrize_with_cases("fixture,api", cases=".") -def test_plot_samples_interactive_map(fixture, api, basemap): - m = api.plot_samples_interactive_map(basemap=basemap) +def test_plot_samples_interactive_map(fixture, api): + # Test behaviour with bad basemap param. + with pytest.raises(ValueError): + api.plot_samples_interactive_map(basemap="foobar") + + # Default params. + m = api.plot_samples_interactive_map() assert isinstance(m, ipyleaflet.Map) + # Explicit params. 
+ for basemap in ["satellite", None, ipyleaflet.basemaps.OpenTopoMap]: + m = api.plot_samples_interactive_map(basemap=basemap, width=500, height=300) + assert isinstance(m, ipyleaflet.Map) + @parametrize_with_cases("fixture,api", cases=".") def test_wgs_data_catalog(fixture, api): diff --git a/tests/anoph/test_snp_frq.py b/tests/anoph/test_snp_frq.py new file mode 100644 index 000000000..6d580b10f --- /dev/null +++ b/tests/anoph/test_snp_frq.py @@ -0,0 +1,1469 @@ +import random + +import numpy as np +import pandas as pd +from pandas.testing import assert_frame_equal +import pytest +from pytest_cases import parametrize_with_cases +import xarray as xr +from numpy.testing import assert_allclose, assert_array_equal +import plotly.graph_objects as go # type: ignore + +from malariagen_data import af1 as _af1 +from malariagen_data import ag3 as _ag3 +from malariagen_data.anoph.snp_frq import AnophelesSnpFrequencyAnalysis +from malariagen_data.util import compare_series_like + + +@pytest.fixture +def ag3_sim_api(ag3_sim_fixture): + return AnophelesSnpFrequencyAnalysis( + url=ag3_sim_fixture.url, + config_path=_ag3.CONFIG_PATH, + gcs_url=_ag3.GCS_URL, + major_version_number=_ag3.MAJOR_VERSION_NUMBER, + major_version_path=_ag3.MAJOR_VERSION_PATH, + pre=True, + aim_metadata_dtype={ + "aim_species_fraction_arab": "float64", + "aim_species_fraction_colu": "float64", + "aim_species_fraction_colu_no2l": "float64", + "aim_species_gambcolu_arabiensis": object, + "aim_species_gambiae_coluzzii": object, + "aim_species": object, + }, + gff_gene_type="gene", + gff_default_attributes=("ID", "Parent", "Name", "description"), + default_site_mask="gamb_colu_arab", + results_cache=ag3_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_ag3.TAXON_COLORS, + ) + + +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesSnpFrequencyAnalysis( + url=af1_sim_fixture.url, + config_path=_af1.CONFIG_PATH, + gcs_url=_af1.GCS_URL, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_default_attributes=("ID", "Parent", "Note", "description"), + default_site_mask="funestus", + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + ) + + +# N.B., here we use pytest_cases to parametrize tests. Each +# function whose name begins with "case_" defines a set of +# inputs to the test functions. See the documentation for +# pytest_cases for more information, e.g.: +# +# https://smarie.github.io/python-pytest-cases/#basic-usage +# +# We use this approach here because we want to use fixtures +# as test parameters, which is otherwise hard to do with +# pytest alone. 
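To make the pytest_cases pattern described in the comment above concrete, here is a minimal standalone sketch (not part of this diff): each case_* function provides one set of inputs, case functions may request fixtures, and parametrize_with_cases(..., cases=".") collects all cases defined in the current module:

import pytest
from pytest_cases import parametrize_with_cases

@pytest.fixture
def small_data():
    return [1, 2, 3]

def case_small(small_data):
    # Cases can request fixtures, which is what makes this pattern useful here.
    return small_data

def case_empty():
    return []

@parametrize_with_cases("data", cases=".")
def test_sum_is_non_negative(data):
    assert sum(data) >= 0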
+ + +def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): + return ag3_sim_fixture, ag3_sim_api + + +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + +expected_alleles = list("ACGT") +expected_effects = [ + "FIVE_PRIME_UTR", + "THREE_PRIME_UTR", + "SYNONYMOUS_CODING", + "NON_SYNONYMOUS_CODING", + "START_LOST", + "STOP_LOST", + "STOP_GAINED", + "SPLICE_CORE", + "SPLICE_REGION", + "INTRONIC", + "TRANSCRIPT", +] +expected_impacts = [ + "HIGH", + "MODERATE", + "LOW", + "MODIFIER", +] + + +def random_transcript(*, api): + df_gff = api.genome_features(attributes=["ID", "Parent"]) + df_transcripts = df_gff.query("type == 'mRNA'") + transcript_ids = df_transcripts["ID"].dropna().to_list() + transcript_id = random.choice(transcript_ids) + transcript = df_transcripts.set_index("ID").loc[transcript_id] + return transcript + + +@parametrize_with_cases("fixture,api", cases=".") +def test_snp_effects(fixture, api: AnophelesSnpFrequencyAnalysis): + # Pick a random transcript. + transcript = random_transcript(api=api) + + # Pick a random site mask. + site_mask = random.choice(api.site_mask_ids + (None,)) + + # Compute effects. + df = api.snp_effects(transcript=transcript.name, site_mask=site_mask) + assert isinstance(df, pd.DataFrame) + + # Check columns. + expected_fields = ( + [ + "contig", + "position", + "ref_allele", + "alt_allele", + ] + + [f"pass_{m}" for m in api.site_mask_ids] + + [ + "transcript", + "effect", + "impact", + "ref_codon", + "alt_codon", + "aa_pos", + "ref_aa", + "alt_aa", + "aa_change", + ] + ) + assert df.columns.tolist() == expected_fields + + # Check some values. + assert np.all(df["contig"] == transcript["contig"]) + position = df["position"].values + assert np.all(position >= transcript["start"]) + assert np.all(position <= transcript["end"]) + assert np.all(position[1:] >= position[:-1]) + expected_alleles = list("ACGT") + assert np.all(df["ref_allele"].isin(expected_alleles)) + assert np.all(df["alt_allele"].isin(expected_alleles)) + assert np.all(df["transcript"] == transcript.name) + assert np.all(df["effect"].isin(expected_effects)) + assert np.all(df["impact"].isin(expected_impacts)) + df_aa = df[~df["aa_change"].isna()] + expected_aa_change = ( + df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"] + ) + assert np.all(df_aa["aa_change"] == expected_aa_change) + + +def check_frequency(x): + loc_nan = np.isnan(x) + assert np.all(x[~loc_nan] >= 0) + assert np.all(x[~loc_nan] <= 1) + + +def check_snp_allele_frequencies( + *, + api, + df, + cohort_labels, + transcript, +): + assert isinstance(df, pd.DataFrame) + + # Check columns. + universal_fields = [f"pass_{m}" for m in api.site_mask_ids] + [ + "label", + ] + effects_fields = [ + "transcript", + "effect", + "impact", + "ref_codon", + "alt_codon", + "aa_pos", + "ref_aa", + "alt_aa", + ] + frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"] + expected_fields = universal_fields + frq_fields + effects_fields + assert sorted(df.columns.tolist()) == sorted(expected_fields) + assert df.index.names == [ + "contig", + "position", + "ref_allele", + "alt_allele", + "aa_change", + ] + + # Check some values. 
+ df = df.reset_index() + assert np.all(df["contig"] == transcript["contig"]) + position = df["position"].values + assert np.all(position >= transcript["start"]) + assert np.all(position <= transcript["end"]) + assert np.all(position[1:] >= position[:-1]) + assert np.all(df["ref_allele"].isin(expected_alleles)) + assert np.all(df["alt_allele"].isin(expected_alleles)) + assert np.all(df["transcript"] == transcript.name) + assert np.all(df["effect"].isin(expected_effects)) + assert np.all(df["impact"].isin(expected_impacts)) + df_aa = df[~df["aa_change"].isna()] + expected_aa_change = ( + df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"] + ) + assert np.all(df_aa["aa_change"] == expected_aa_change) + for f in frq_fields: + x = df[f] + check_frequency(x) + + +def check_aa_allele_frequencies( + *, + df, + cohort_labels, + transcript, +): + assert isinstance(df, pd.DataFrame) + + # Check columns. + universal_fields = [ + "label", + ] + effects_fields = [ + "transcript", + "effect", + "impact", + "aa_pos", + "ref_allele", + "alt_allele", + "ref_aa", + "alt_aa", + ] + frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"] + expected_fields = universal_fields + frq_fields + effects_fields + expected_fields = universal_fields + frq_fields + effects_fields + assert sorted(df.columns.tolist()) == sorted(expected_fields) + assert df.index.names == [ + "aa_change", + "contig", + "position", + ] + + # Check some values. + df = df.reset_index() + assert np.all(df["contig"] == transcript["contig"]) + position = df["position"].values + assert np.all(position >= transcript["start"]) + assert np.all(position <= transcript["end"]) + assert np.all(position[1:] >= position[:-1]) + assert np.all(df["ref_allele"].isin(expected_alleles)) + # N.B., alt_allele may contain multiple alleles, e.g., "{A,T}", if + # multiple SNP alleles at the same position cause the same amino acid + # change. + assert np.all(df["transcript"] == transcript.name) + assert np.all(df["effect"].isin(expected_effects)) + assert np.all(df["impact"].isin(expected_impacts)) + df_aa = df[~df["aa_change"].isna()] + expected_aa_change = ( + df_aa["ref_aa"] + df_aa["aa_pos"].astype(int).astype(str) + df_aa["alt_aa"] + ) + assert np.all(df_aa["aa_change"] == expected_aa_change) + for f in frq_fields: + x = df[f] + check_frequency(x) + + +@pytest.mark.parametrize( + "cohorts", ["admin1_year", "admin2_month", "country", "foobar"] +) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_str_cohorts( + fixture, + api: AnophelesSnpFrequencyAnalysis, + cohorts, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api) + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=True, + ) + + # Test behaviour with bad cohorts param. + if cohorts == "foobar": + with pytest.raises(ValueError): + api.snp_allele_frequencies(**params) + return + + # Run the function under test. + df_snp = api.snp_allele_frequencies(**params) + + # Figure out expected cohort labels. 
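+    # Cohort labels are derived from the sample metadata: prefer a derived
+    # "cohort_<cohorts>" column (e.g., "cohort_admin1_year") if present,
+    # otherwise treat the cohorts parameter itself as a metadata column, and
+    # keep only cohorts with at least min_cohort_size samples.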
+ df_samples = api.sample_metadata(sample_sets=sample_sets) + if "cohort_" + cohorts in df_samples: + cohort_column = "cohort_" + cohorts + else: + cohort_column = cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Run the function under test. + df_aa = api.aa_allele_frequencies(**params) + + # Standard checks. + check_aa_allele_frequencies( + df=df_aa, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + +@pytest.mark.parametrize("min_cohort_size", [0, 10, 100]) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_min_cohort_size( + fixture, + api: AnophelesSnpFrequencyAnalysis, + min_cohort_size, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + transcript = random_transcript(api=api) + cohorts = "admin1_year" + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=True, + ) + + # Figure out expected cohort labels. + df_samples = api.sample_metadata(sample_sets=sample_sets) + if "cohort_" + cohorts in df_samples: + cohort_column = "cohort_" + cohorts + else: + cohort_column = cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + if len(cohort_labels) == 0: + # No cohorts, expect error. + with pytest.raises(ValueError): + api.snp_allele_frequencies(**params) + with pytest.raises(ValueError): + api.aa_allele_frequencies(**params) + return + + # Run the function under test. + df_snp = api.snp_allele_frequencies(**params) + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Run the function under test. + df_aa = api.aa_allele_frequencies(**params) + + # Standard checks. + check_aa_allele_frequencies( + df=df_aa, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_str_cohorts_and_sample_query( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + sample_sets = None + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = 0 + transcript = random_transcript(api=api) + cohorts = random.choice( + ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] + ) + df_samples = api.sample_metadata(sample_sets=sample_sets) + countries = df_samples["country"].unique() + country = random.choice(countries) + sample_query = f"country == '{country}'" + + # Figure out expected cohort labels. + df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query) + cohort_column = "cohort_" + cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Set up call params. 
+ params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + sample_query=sample_query, + drop_invariant=True, + ) + + # Run the function under test. + df_snp = api.snp_allele_frequencies(**params) + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Run the function under test. + df_aa = api.aa_allele_frequencies(**params) + + # Standard checks. + check_aa_allele_frequencies( + df=df_aa, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_dict_cohorts( + fixture, api: AnophelesSnpFrequencyAnalysis +): + # Pick test parameters at random. + sample_sets = None # all sample sets + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api) + + # Create cohorts by country. + df_samples = api.sample_metadata(sample_sets=sample_sets) + cohort_counts = df_samples["country"].value_counts() + cohorts = {cohort: f"country == '{cohort}'" for cohort in cohort_counts.index} + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=True, + ) + + # Run the function under test. + df_snp = api.snp_allele_frequencies(**params) + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Run the function under test. + df_aa = api.aa_allele_frequencies(**params) + + # Standard checks. + check_aa_allele_frequencies( + df=df_aa, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_without_drop_invariant( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api) + cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + + # Figure out expected cohort labels. + df_samples = api.sample_metadata(sample_sets=sample_sets) + if "cohort_" + cohorts in df_samples: + cohort_column = "cohort_" + cohorts + else: + cohort_column = cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + ) + + # Run the function under test. + df_snp_a = api.snp_allele_frequencies(drop_invariant=True, **params) + df_snp_b = api.snp_allele_frequencies(drop_invariant=False, **params) + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp_a, + cohort_labels=cohort_labels, + transcript=transcript, + ) + check_snp_allele_frequencies( + api=api, + df=df_snp_b, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Check specifics. 
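+    # With drop_invariant=False invariant sites are retained, so the result
+    # should have strictly more rows than with drop_invariant=True.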
+ assert len(df_snp_b) > len(df_snp_a) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_without_effects( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api) + cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + + # Figure out expected cohort labels. + df_samples = api.sample_metadata(sample_sets=sample_sets) + if "cohort_" + cohorts in df_samples: + cohort_column = "cohort_" + cohorts + else: + cohort_column = cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=True, + ) + + # Run the function under test. + df_snp_a = api.snp_allele_frequencies(effects=True, **params) + df_snp_b = api.snp_allele_frequencies(effects=False, **params) + + # Standard checks. + check_snp_allele_frequencies( + api=api, + df=df_snp_a, + cohort_labels=cohort_labels, + transcript=transcript, + ) + + # Check specifics. + assert len(df_snp_b) == len(df_snp_a) + + # Check columns and index names. + filter_fields = [f"pass_{m}" for m in api.site_mask_ids] + universal_fields = filter_fields + ["label"] + frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"] + expected_fields = universal_fields + frq_fields + assert sorted(df_snp_b.columns.tolist()) == sorted(expected_fields) + assert df_snp_b.index.names == [ + "contig", + "position", + "ref_allele", + "alt_allele", + ] + + # Compare values with and without effects. + comparable_fields = ( + [ + "contig", + "position", + "ref_allele", + "alt_allele", + ] + + filter_fields + + frq_fields + ) + # N.B., values of the "label" field are different with and without + # effects, so don't compare them. + assert_frame_equal( + df_snp_b.reset_index()[comparable_fields], + df_snp_a.reset_index()[comparable_fields], + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_bad_transcript( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + + # Set up call params. + params = dict( + transcript="foobar", + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=True, + ) + + # Run the function under test. + with pytest.raises(ValueError): + api.snp_allele_frequencies(**params) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_region( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. 
+ all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + # This should work, as long as effects=False - i.e., can get frequencies + # for any genome region. + transcript = fixture.random_region_str(region_size=500) + + # Set up call params. + params = dict( + transcript=transcript, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + drop_invariant=False, + effects=False, + ) + + # Run the function under test. + df_snp = api.snp_allele_frequencies(**params) + + # Basic checks. + assert isinstance(df_snp, pd.DataFrame) + assert len(df_snp) > 0 + + # Figure out expected cohort labels. + df_samples = api.sample_metadata(sample_sets=sample_sets) + if "cohort_" + cohorts in df_samples: + cohort_column = "cohort_" + cohorts + else: + cohort_column = cohorts + cohort_counts = df_samples[cohort_column].value_counts() + cohort_labels = cohort_counts[cohort_counts >= min_cohort_size].index.to_list() + + # Check columns and index names. + filter_fields = [f"pass_{m}" for m in api.site_mask_ids] + universal_fields = filter_fields + ["label"] + frq_fields = ["frq_" + s for s in cohort_labels] + ["max_af"] + expected_fields = universal_fields + frq_fields + assert sorted(df_snp.columns.tolist()) == sorted(expected_fields) + assert df_snp.index.names == [ + "contig", + "position", + "ref_allele", + "alt_allele", + ] + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_with_dup_samples( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_set = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api) + cohorts = random.choice(["admin1_year", "admin2_month", "country"]) + + # Set up call params. + params = dict( + transcript=transcript.name, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + ) + + # Run the function under test. + df_snp_a = api.snp_allele_frequencies(sample_sets=[sample_set], **params) + df_snp_b = api.snp_allele_frequencies( + sample_sets=[sample_set, sample_set], **params + ) + + # Expect automatically deduplicate sample sets. + assert_frame_equal(df_snp_b, df_snp_a) + + # Run the function under test. + df_aa_a = api.aa_allele_frequencies(sample_sets=[sample_set], **params) + df_aa_b = api.aa_allele_frequencies(sample_sets=[sample_set, sample_set], **params) + + # Expect automatically deduplicate sample sets. + assert_frame_equal(df_aa_b, df_aa_a) + + +def check_snp_allele_frequencies_advanced( + *, + api: AnophelesSnpFrequencyAnalysis, + transcript=None, + area_by="admin1_iso", + period_by="year", + sample_sets=None, + sample_query=None, + min_cohort_size=None, + nobs_mode="called", + variant_query=None, + site_mask=None, +): + # Pick test parameters at random. 
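+    # Any parameter the caller leaves as None is filled in here with a
+    # randomly chosen valid value, so different runs exercise different
+    # parameter combinations.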
+ if transcript is None: + transcript = random_transcript(api=api).name + if area_by is None: + area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + if period_by is None: + period_by = random.choice(["year", "quarter", "month"]) + if sample_sets is None: + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + if min_cohort_size is None: + min_cohort_size = random.randint(0, 2) + if site_mask is None: + site_mask = random.choice(api.site_mask_ids + (None,)) + + # Run function under test. + ds = api.snp_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + nobs_mode=nobs_mode, + variant_query=variant_query, + site_mask=site_mask, + ) + + # Check the result. + assert isinstance(ds, xr.Dataset) + assert set(ds.dims) == {"cohorts", "variants"} + + # Check variant variables. + expected_variant_vars = [ + "variant_label", + "variant_contig", + "variant_position", + "variant_ref_allele", + "variant_alt_allele", + "variant_max_af", + "variant_transcript", + "variant_effect", + "variant_impact", + "variant_ref_codon", + "variant_alt_codon", + "variant_ref_aa", + "variant_alt_aa", + "variant_aa_pos", + "variant_aa_change", + ] + expected_variant_vars += [f"variant_pass_{m}" for m in api.site_mask_ids] + for v in expected_variant_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("variants",) + + # Check cohort variables. + expected_cohort_vars = [ + "cohort_label", + "cohort_size", + "cohort_taxon", + "cohort_area", + "cohort_period", + "cohort_period_start", + "cohort_period_end", + "cohort_lat_mean", + "cohort_lat_min", + "cohort_lat_max", + "cohort_lon_mean", + "cohort_lon_min", + "cohort_lon_max", + ] + for v in expected_cohort_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("cohorts",) + + # Check event variables. + expected_event_vars = [ + "event_count", + "event_nobs", + "event_frequency", + "event_frequency_ci_low", + "event_frequency_ci_upp", + ] + for v in expected_event_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("variants", "cohorts") + + # Sanity check for frequency values. + x = ds["event_frequency"].values + check_frequency(x) + + # Sanity check area values. + df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query) + expected_area_values = np.unique(df_samples[area_by].dropna().values) + area_values = ds["cohort_area"].values + # N.B., some areas may not end up in final dataset if cohort + # size is too small, so do a set membership test + for a in area_values: + assert a in expected_area_values + + # Sanity checks for period values. + period_values = ds["cohort_period"].values + if period_by == "year": + expected_freqstr = "A-DEC" + elif period_by == "month": + expected_freqstr = "M" + elif period_by == "quarter": + expected_freqstr = "Q-DEC" + else: + assert False, "not implemented" + for p in period_values: + assert isinstance(p, pd.Period) + assert p.freqstr == expected_freqstr + + # Sanity check cohort sizes. + cohort_size_values = ds["cohort_size"].values + for s in cohort_size_values: + assert s >= min_cohort_size + + if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": + # Here we test the behaviour of the function when grouping by admin level + # 1 and year. 
We can do some more in-depth testing in this case because + # we can compare results directly against the simpler snp_allele_frequencies() + # function with the admin1_year cohorts. + + # Check consistency with the basic snp allele frequencies method. + df_af = api.snp_allele_frequencies( + transcript=transcript, + cohorts="admin1_year", + sample_sets=sample_sets, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + include_counts=True, + ) + # Make sure all variables available to check. + df_af = df_af.reset_index() + if variant_query is not None: + df_af = df_af.query(variant_query) + + # Check cohorts are consistent. + expect_cohort_labels = sorted( + [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] + ) + cohort_labels = sorted(ds["cohort_label"].values) + assert cohort_labels == expect_cohort_labels + + # Check variants are consistent. + assert ds.sizes["variants"] == len(df_af) + for v in expected_variant_vars: + c = v.split("variant_")[1] + actual = ds[v] + expect = df_af[c] + compare_series_like(actual, expect) + + # Check frequencies are consistent. + for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): + actual_nobs = ds["event_nobs"].values[:, cohort_index] + expect_nobs = df_af[f"nobs_{cohort_label}"].values + assert_array_equal(actual_nobs, expect_nobs) + actual_count = ds["event_count"].values[:, cohort_index] + expect_count = df_af[f"count_{cohort_label}"].values + assert_array_equal(actual_count, expect_count) + actual_frq = ds["event_frequency"].values[:, cohort_index] + expect_frq = df_af[f"frq_{cohort_label}"].values + assert_allclose(actual_frq, expect_frq) + + +def check_aa_allele_frequencies_advanced( + *, + api: AnophelesSnpFrequencyAnalysis, + transcript=None, + area_by="admin1_iso", + period_by="year", + sample_sets=None, + sample_query=None, + min_cohort_size=None, + nobs_mode="called", + variant_query=None, +): + # Pick test parameters at random. + if transcript is None: + transcript = random_transcript(api=api).name + if area_by is None: + area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + if period_by is None: + period_by = random.choice(["year", "quarter", "month"]) + if sample_sets is None: + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + if min_cohort_size is None: + min_cohort_size = random.randint(0, 2) + + # Run function under test. + ds = api.aa_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + nobs_mode=nobs_mode, + variant_query=variant_query, + ) + + # Check the result. 
+ assert isinstance(ds, xr.Dataset) + assert set(ds.dims) == {"cohorts", "variants"} + + expected_variant_vars = ( + "variant_label", + "variant_contig", + "variant_position", + "variant_max_af", + "variant_transcript", + "variant_effect", + "variant_impact", + "variant_ref_aa", + "variant_alt_aa", + "variant_aa_pos", + "variant_aa_change", + ) + for v in expected_variant_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("variants",) + + expected_cohort_vars = ( + "cohort_label", + "cohort_size", + "cohort_taxon", + "cohort_area", + "cohort_period", + "cohort_period_start", + "cohort_period_end", + "cohort_lat_mean", + "cohort_lat_min", + "cohort_lat_max", + "cohort_lon_mean", + "cohort_lon_min", + "cohort_lon_max", + ) + for v in expected_cohort_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("cohorts",) + + expected_event_vars = ( + "event_count", + "event_nobs", + "event_frequency", + "event_frequency_ci_low", + "event_frequency_ci_upp", + ) + for v in expected_event_vars: + a = ds[v] + assert isinstance(a, xr.DataArray) + assert a.dims == ("variants", "cohorts") + + # Sanity check for frequency values. + x = ds["event_frequency"].values + check_frequency(x) + + # Sanity checks for area values. + df_samples = api.sample_metadata(sample_sets=sample_sets, sample_query=sample_query) + expected_area_values = np.unique(df_samples[area_by].dropna().values) + area_values = ds["cohort_area"].values + # N.B., some areas may not end up in final dataset if cohort + # size is too small, so do a set membership test + for a in area_values: + assert a in expected_area_values + + # Sanity checks for period values. + period_values = ds["cohort_period"].values + if period_by == "year": + expected_freqstr = "A-DEC" + elif period_by == "month": + expected_freqstr = "M" + elif period_by == "quarter": + expected_freqstr = "Q-DEC" + else: + assert False, "not implemented" + for p in period_values: + assert isinstance(p, pd.Period) + assert p.freqstr == expected_freqstr + + # Sanity check cohort size. + cohort_size_values = ds["cohort_size"].values + for s in cohort_size_values: + assert s >= min_cohort_size + + if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": + # Here we test the behaviour of the function when grouping by admin level + # 1 and year. We can do some more in-depth testing in this case because + # we can compare results directly against the simpler aa_allele_frequencies() + # function with the admin1_year cohorts. + + # Check consistency with the basic aa allele frequencies method. + df_af = api.aa_allele_frequencies( + transcript=transcript, + cohorts="admin1_year", + sample_sets=sample_sets, + sample_query=sample_query, + min_cohort_size=min_cohort_size, + include_counts=True, + ) + # Make sure all variables available to check. + df_af = df_af.reset_index() + if variant_query is not None: + df_af = df_af.query(variant_query) + + # Check cohorts are consistent. + expect_cohort_labels = sorted( + [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] + ) + cohort_labels = sorted(ds["cohort_label"].values) + assert cohort_labels == expect_cohort_labels + + # Check variants are consistent. + assert ds.sizes["variants"] == len(df_af) + for v in expected_variant_vars: + c = v.split("variant_")[1] + actual = ds[v] + expect = df_af[c] + compare_series_like(actual, expect) + + # Check frequencies are consistent. 
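+        # Compare nobs, counts and frequencies cohort-by-cohort against the
+        # corresponding nobs_*, count_* and frq_* columns returned by the
+        # simpler aa_allele_frequencies() method.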
+ for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): + actual_nobs = ds["event_nobs"].values[:, cohort_index] + expect_nobs = df_af[f"nobs_{cohort_label}"].values + assert_array_equal(actual_nobs, expect_nobs) + actual_count = ds["event_count"].values[:, cohort_index] + expect_count = df_af[f"count_{cohort_label}"].values + assert_array_equal(actual_count, expect_count) + actual_frq = ds["event_frequency"].values[:, cohort_index] + expect_frq = df_af[f"frq_{cohort_label}"].values + assert_allclose(actual_frq, expect_frq) + + +# Here we don't explore the full matrix, but vary one parameter at a time, otherwise +# the test suite would take too long to run. + + +@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"]) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_area_by( + fixture, + api: AnophelesSnpFrequencyAnalysis, + area_by, +): + check_snp_allele_frequencies_advanced( + api=api, + area_by=area_by, + ) + check_aa_allele_frequencies_advanced( + api=api, + area_by=area_by, + ) + + +@pytest.mark.parametrize("period_by", ["year", "quarter", "month"]) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_period_by( + fixture, + api: AnophelesSnpFrequencyAnalysis, + period_by, +): + check_snp_allele_frequencies_advanced( + api=api, + period_by=period_by, + ) + check_aa_allele_frequencies_advanced( + api=api, + period_by=period_by, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_sample_query( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + all_sample_sets = api.sample_sets()["sample_set"].to_list() + df_samples = api.sample_metadata(sample_sets=all_sample_sets) + countries = df_samples["country"].unique() + country = random.choice(countries) + sample_query = f"country == '{country}'" + + check_snp_allele_frequencies_advanced( + api=api, + sample_sets=all_sample_sets, + sample_query=sample_query, + min_cohort_size=0, + ) + check_aa_allele_frequencies_advanced( + api=api, + sample_sets=all_sample_sets, + sample_query=sample_query, + min_cohort_size=0, + ) + + +@pytest.mark.parametrize("min_cohort_size", [0, 10, 100]) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_min_cohort_size( + fixture, + api: AnophelesSnpFrequencyAnalysis, + min_cohort_size, +): + all_sample_sets = api.sample_sets()["sample_set"].to_list() + area_by = "admin1_iso" + period_by = "year" + transcript = random_transcript(api=api).name + + if min_cohort_size <= 10: + # Expect this to find at least one cohort, so go ahead with full + # checks. + check_snp_allele_frequencies_advanced( + api=api, + transcript=transcript, + sample_sets=all_sample_sets, + min_cohort_size=min_cohort_size, + area_by=area_by, + period_by=period_by, + ) + check_aa_allele_frequencies_advanced( + api=api, + transcript=transcript, + sample_sets=all_sample_sets, + min_cohort_size=min_cohort_size, + area_by=area_by, + period_by=period_by, + ) + else: + # Expect this to find no cohorts. 
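+        # With a min_cohort_size this large, the simulated sample sets are
+        # expected to yield no qualifying cohorts, so both methods should
+        # raise ValueError.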
+ with pytest.raises(ValueError): + api.snp_allele_frequencies_advanced( + transcript=transcript, + sample_sets=all_sample_sets, + min_cohort_size=min_cohort_size, + area_by=area_by, + period_by=period_by, + ) + with pytest.raises(ValueError): + api.aa_allele_frequencies_advanced( + transcript=transcript, + sample_sets=all_sample_sets, + min_cohort_size=min_cohort_size, + area_by=area_by, + period_by=period_by, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_variant_query( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + all_sample_sets = api.sample_sets()["sample_set"].to_list() + area_by = "admin1_iso" + period_by = "year" + transcript = random_transcript(api=api).name + + # Test a query that should succeed. + variant_query = "effect == 'NON_SYNONYMOUS_CODING'" + check_snp_allele_frequencies_advanced( + api=api, + transcript=transcript, + sample_sets=all_sample_sets, + area_by=area_by, + period_by=period_by, + variant_query=variant_query, + ) + check_aa_allele_frequencies_advanced( + api=api, + transcript=transcript, + sample_sets=all_sample_sets, + area_by=area_by, + period_by=period_by, + variant_query=variant_query, + ) + + # Test a query that should fail. + variant_query = "effect == 'foobar'" + with pytest.raises(ValueError): + api.snp_allele_frequencies_advanced( + transcript=transcript, + sample_sets=all_sample_sets, + area_by=area_by, + period_by=period_by, + variant_query=variant_query, + ) + with pytest.raises(ValueError): + api.aa_allele_frequencies_advanced( + transcript=transcript, + sample_sets=all_sample_sets, + area_by=area_by, + period_by=period_by, + variant_query=variant_query, + ) + + +@pytest.mark.parametrize("nobs_mode", ["called", "fixed"]) +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_nobs_mode( + fixture, + api: AnophelesSnpFrequencyAnalysis, + nobs_mode, +): + check_snp_allele_frequencies_advanced( + api=api, + nobs_mode=nobs_mode, + ) + check_aa_allele_frequencies_advanced( + api=api, + nobs_mode=nobs_mode, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_allele_frequencies_advanced_with_dup_samples( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_set = random.choice(all_sample_sets) + sample_sets = [sample_set, sample_set] + + check_snp_allele_frequencies_advanced( + api=api, + sample_sets=sample_sets, + ) + check_aa_allele_frequencies_advanced( + api=api, + sample_sets=sample_sets, + ) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_plot_frequencies_heatmap( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api).name + cohorts = random.choice( + ["admin1_year", "admin1_month", "admin2_year", "admin2_month"] + ) + + # Set up call params. + params = dict( + transcript=transcript, + cohorts=cohorts, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + sample_sets=sample_sets, + ) + + # Test SNP allele frequencies. + df_snp = api.snp_allele_frequencies(**params) + fig = api.plot_frequencies_heatmap(df_snp, show=False, max_len=None) + assert isinstance(fig, go.Figure) + + # Test amino acid change allele frequencies. 
+ df_aa = api.aa_allele_frequencies(**params) + fig = api.plot_frequencies_heatmap(df_aa, show=False, max_len=None) + assert isinstance(fig, go.Figure) + + # Test max_len behaviour. + with pytest.raises(ValueError): + api.plot_frequencies_heatmap(df_snp, show=False, max_len=len(df_snp) - 1) + + # Test index parameter - if None, should use dataframe index. + fig = api.plot_frequencies_heatmap(df_snp, show=False, index=None, max_len=None) + # Not unique. + with pytest.raises(ValueError): + api.plot_frequencies_heatmap(df_snp, show=False, index="contig", max_len=None) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_plot_frequencies_time_series( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api).name + area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + period_by = random.choice(["year", "quarter", "month"]) + + # Compute SNP frequencies. + ds = api.snp_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + ) + + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_time_series(ds, show=False) + + # Test. + assert isinstance(fig, go.Figure) + + # Compute amino acid change frequencies. + ds = api.aa_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + ) + + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_time_series(ds, show=False) + + # Test. + assert isinstance(fig, go.Figure) + + +@parametrize_with_cases("fixture,api", cases=".") +def test_plot_frequencies_interactive_map( + fixture, + api: AnophelesSnpFrequencyAnalysis, +): + import ipywidgets # type: ignore + + # Pick test parameters at random. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + sample_sets = random.choice(all_sample_sets) + site_mask = random.choice(api.site_mask_ids + (None,)) + min_cohort_size = random.randint(0, 2) + transcript = random_transcript(api=api).name + area_by = random.choice(["country", "admin1_iso", "admin2_name"]) + period_by = random.choice(["year", "quarter", "month"]) + + # Compute SNP frequencies. + ds = api.snp_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + site_mask=site_mask, + ) + + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_interactive_map(ds) + + # Test. + assert isinstance(fig, ipywidgets.Widget) + + # Compute amino acid change frequencies. + ds = api.aa_allele_frequencies_advanced( + transcript=transcript, + area_by=area_by, + period_by=period_by, + sample_sets=sample_sets, + min_cohort_size=min_cohort_size, + ) + + # Trim things down a bit for speed. + ds = ds.isel(variants=slice(0, 100)) + + # Plot. + fig = api.plot_frequencies_interactive_map(ds) + + # Test. 
+ assert isinstance(fig, ipywidgets.Widget) diff --git a/tests/test_af1.py b/tests/test_af1.py index b849225cd..2ef10dcc8 100644 --- a/tests/test_af1.py +++ b/tests/test_af1.py @@ -1,8 +1,6 @@ import numpy as np -import pandas as pd import pytest -import xarray as xr -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose from malariagen_data import Af1, Region from malariagen_data.util import locate_region, resolve_region @@ -29,142 +27,6 @@ def test_repr(): assert isinstance(r, str) -# TODO: test_snp_effects() for Af1.0 -# # reverse strand gene -# gste2 = "LOC125761549" -# -# # test forward strand gene gste6 -# gste6 = "LOC125767311" -# -# # check 5' utr intron and the different intron effects -# utr_intron5 = "utr_LOC125767311_t1_1" -# -# # check 3' utr intron -# utr_intron3 = "utr_LOC125767311_t1_3" - - -def test_snp_allele_frequencies__dict_cohorts(): - af1 = setup_af1(cohorts_analysis="20221129") - cohorts = { - "ke": "country == 'Kenya'", - "gh_2017": "country == 'Ghana' and year == 2017", - } - universal_fields = [ - "pass_funestus", - "label", - ] - - # test drop invariants - df = af1.snp_allele_frequencies( - transcript="LOC125761549_t5", - cohorts=cohorts, - site_mask="funestus", - sample_sets="1.0", - drop_invariant=True, - effects=False, - ) - - assert isinstance(df, pd.DataFrame) - frq_columns = ["frq_" + s for s in list(cohorts.keys())] - expected_fields = universal_fields + frq_columns + ["max_af"] - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.shape == (3185, len(expected_fields)) - assert df.iloc[2].frq_ke == 0 - assert df.iloc[2].frq_gh_2017 == pytest.approx(0.013889, abs=1e-6) - assert df.iloc[2].max_af == pytest.approx(0.013889, abs=1e-6) - # check invariant have been dropped - assert df.max_af.min() > 0 - - # test keep invariants - df = af1.snp_allele_frequencies( - transcript="LOC125761549_t7", - cohorts=cohorts, - site_mask="funestus", - sample_sets="1.0", - drop_invariant=False, - effects=False, - ) - assert isinstance(df, pd.DataFrame) - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.shape == (29418, len(expected_fields)) - # check invariant positions are still present - assert np.any(df.max_af == 0) - - -def test_snp_allele_frequencies__str_cohorts__effects(): - af1 = setup_af1(cohorts_analysis="20221129") - cohorts = "admin1_month" - min_cohort_size = 10 - universal_fields = [ - "pass_funestus", - "label", - ] - effects_fields = [ - "transcript", - "effect", - "impact", - "ref_codon", - "alt_codon", - "aa_pos", - "ref_aa", - "alt_aa", - ] - df = af1.snp_allele_frequencies( - transcript="LOC125767311_t2", - cohorts=cohorts, - min_cohort_size=min_cohort_size, - site_mask="funestus", - sample_sets="1.0", - drop_invariant=True, - effects=True, - ) - df_coh = af1.cohorts_metadata(sample_sets="1.0") - coh_nm = "cohort_" + cohorts - coh_counts = df_coh[coh_nm].dropna().value_counts() - cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list() - frq_cohort_labels = ["frq_" + s for s in cohort_labels] - expected_fields = universal_fields + frq_cohort_labels + ["max_af"] + effects_fields - - assert isinstance(df, pd.DataFrame) - assert len(df) == 4221 - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.index.names == [ - "contig", - "position", - "ref_allele", - "alt_allele", - "aa_change", - ] - - -def test_snp_allele_frequencies__query(): - af1 = setup_af1(cohorts_analysis="20221129") - cohorts = "admin1_year" - 
min_cohort_size = 10 - expected_columns = [ - "pass_funestus", - "frq_GH-AH_fune_2014", - "frq_GH-NP_fune_2017", - "max_af", - "label", - ] - - df = af1.snp_allele_frequencies( - transcript="LOC125767311_t2", - cohorts=cohorts, - sample_query="country == 'Ghana'", - min_cohort_size=min_cohort_size, - site_mask="funestus", - sample_sets="1.0", - drop_invariant=True, - effects=False, - ) - - assert isinstance(df, pd.DataFrame) - assert sorted(df.columns) == sorted(expected_columns) - assert len(df) == 1309 - - @pytest.mark.parametrize( "region_raw", [ @@ -216,470 +78,6 @@ def test_locate_region(region_raw): assert region == Region("2RL", 24630355, 24633221) -def test_aa_allele_frequencies(): - af1 = setup_af1(cohorts_analysis="20221129") - - expected_fields = [ - "transcript", - "aa_pos", - "ref_allele", - "alt_allele", - "ref_aa", - "alt_aa", - "effect", - "impact", - "frq_CD-HU_fune_2017", - "frq_MZ-P_fune_2015", - "max_af", - "label", - ] - - df = af1.aa_allele_frequencies( - transcript="LOC125767311_t2", - cohorts="admin1_year", - min_cohort_size=10, - site_mask="funestus", - sample_sets=("1240-VO-CD-KOEKEMOER-VMF00099", "1240-VO-MZ-KOEKEMOER-VMF00101"), - drop_invariant=True, - ) - - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert isinstance(df, pd.DataFrame) - assert df.index.names == ["aa_change", "contig", "position"] - assert df.shape == (53, len(expected_fields)) - assert df.loc["V947L"].max_af[0] == pytest.approx(0.025, abs=1e-6) - - -# noinspection PyDefaultArgument -def _check_snp_allele_frequencies_advanced( - transcript="LOC125767311_t2", - area_by="admin1_iso", - period_by="year", - sample_sets=[ - "1229-VO-GH-DADZIE-VMF00095", - "1240-VO-CD-KOEKEMOER-VMF00099", - "1240-VO-MZ-KOEKEMOER-VMF00101", - ], - sample_query=None, - min_cohort_size=10, - nobs_mode="called", - variant_query="max_af > 0.02", -): - af1 = setup_af1(cohorts_analysis="20221129") - - ds = af1.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - nobs_mode=nobs_mode, - variant_query=variant_query, - ) - - assert isinstance(ds, xr.Dataset) - - # noinspection PyTypeChecker - assert sorted(ds.dims) == ["cohorts", "variants"] - - expected_variant_vars = ( - "variant_label", - "variant_contig", - "variant_position", - "variant_ref_allele", - "variant_alt_allele", - "variant_max_af", - "variant_pass_funestus", - "variant_transcript", - "variant_effect", - "variant_impact", - "variant_ref_codon", - "variant_alt_codon", - "variant_ref_aa", - "variant_alt_aa", - "variant_aa_pos", - "variant_aa_change", - ) - for v in expected_variant_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants",) - - expected_cohort_vars = ( - "cohort_label", - "cohort_size", - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_period_start", - "cohort_period_end", - "cohort_lat_mean", - "cohort_lat_min", - "cohort_lat_max", - "cohort_lon_mean", - "cohort_lon_min", - "cohort_lon_max", - ) - for v in expected_cohort_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("cohorts",) - - expected_event_vars = ( - "event_count", - "event_nobs", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ) - for v in expected_event_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants", "cohorts") - - # sanity checks for area values - df_samples = 
af1.sample_metadata(sample_sets=sample_sets) - if sample_query is not None: - df_samples = df_samples.query(sample_query) - expected_area = np.unique(df_samples[area_by].dropna().values) - area = ds["cohort_area"].values - # N.B., some areas may not end up in final dataset if cohort - # size is too small, so do a set membership test - for a in area: - assert a in expected_area - - # sanity checks for period values - period = ds["cohort_period"].values - if period_by == "year": - expected_freqstr = "A-DEC" - elif period_by == "month": - expected_freqstr = "M" - elif period_by == "quarter": - expected_freqstr = "Q-DEC" - else: - assert False, "not implemented" - for p in period: - assert isinstance(p, pd.Period) - assert p.freqstr == expected_freqstr - - # sanity check cohort size - size = ds["cohort_size"].values - for s in size: - assert s >= min_cohort_size - - if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": - # Here we test the behaviour of the function when grouping by admin level - # 1 and year. We can do some more in-depth testing in this case because - # we can compare results directly against the simpler snp_allele_frequencies() - # function with the admin1_year cohorts. - - # check consistency with the basic snp allele frequencies method - df_af = af1.snp_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - ) - df_af = df_af.reset_index() # make sure all variables available to check - if variant_query is not None: - df_af = df_af.query(variant_query) - - # check cohorts are consistent - expect_cohort_labels = sorted( - [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] - ) - cohort_labels = sorted(ds["cohort_label"].values) - assert cohort_labels == expect_cohort_labels - - # check variants are consistent - assert ds.sizes["variants"] == len(df_af) - for v in expected_variant_vars: - c = v.split("variant_")[1] - actual = ds[v] - expect = df_af[c] - _compare_series_like(actual, expect) - - # check frequencies are consistent - for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): - actual_frq = ds["event_frequency"].values[:, cohort_index] - expect_frq = df_af[f"frq_{cohort_label}"].values - assert_allclose(actual_frq, expect_frq) - - -# noinspection PyDefaultArgument -def _check_aa_allele_frequencies_advanced( - transcript="LOC125767311_t2", - area_by="admin1_iso", - period_by="year", - sample_sets=[ - "1229-VO-GH-DADZIE-VMF00095", - "1240-VO-CD-KOEKEMOER-VMF00099", - "1240-VO-MZ-KOEKEMOER-VMF00101", - ], - sample_query=None, - min_cohort_size=10, - nobs_mode="called", - variant_query="max_af > 0.02", -): - af1 = setup_af1(cohorts_analysis="20221129") - - ds = af1.aa_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - nobs_mode=nobs_mode, - variant_query=variant_query, - ) - - assert isinstance(ds, xr.Dataset) - - # noinspection PyTypeChecker - assert sorted(ds.dims) == ["cohorts", "variants"] - - expected_variant_vars = ( - "variant_label", - "variant_contig", - "variant_position", - "variant_max_af", - "variant_transcript", - "variant_effect", - "variant_impact", - "variant_ref_aa", - "variant_alt_aa", - "variant_aa_pos", - "variant_aa_change", - ) - for v in expected_variant_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants",) - - 
expected_cohort_vars = ( - "cohort_label", - "cohort_size", - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_period_start", - "cohort_period_end", - "cohort_lat_mean", - "cohort_lat_min", - "cohort_lat_max", - "cohort_lon_mean", - "cohort_lon_min", - "cohort_lon_max", - ) - for v in expected_cohort_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("cohorts",) - - expected_event_vars = ( - "event_count", - "event_nobs", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ) - for v in expected_event_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants", "cohorts") - - # sanity checks for area values - df_samples = af1.sample_metadata(sample_sets=sample_sets) - if sample_query is not None: - df_samples = df_samples.query(sample_query) - expected_area = np.unique(df_samples[area_by].dropna().values) - area = ds["cohort_area"].values - # N.B., some areas may not end up in final dataset if cohort - # size is too small, so do a set membership test - for a in area: - assert a in expected_area - - # sanity checks for period values - period = ds["cohort_period"].values - if period_by == "year": - expected_freqstr = "A-DEC" - elif period_by == "month": - expected_freqstr = "M" - elif period_by == "quarter": - expected_freqstr = "Q-DEC" - else: - assert False, "not implemented" - for p in period: - assert isinstance(p, pd.Period) - assert p.freqstr == expected_freqstr - - # sanity check cohort size - size = ds["cohort_size"].values - for s in size: - assert s >= min_cohort_size - - if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": - # Here we test the behaviour of the function when grouping by admin level - # 1 and year. We can do some more in-depth testing in this case because - # we can compare results directly against the simpler aa_allele_frequencies() - # function with the admin1_year cohorts. - - # check consistency with the basic snp allele frequencies method - df_af = af1.aa_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - ) - df_af = df_af.reset_index() # make sure all variables available to check - if variant_query is not None: - df_af = df_af.query(variant_query) - - # check cohorts are consistent - expect_cohort_labels = sorted( - [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] - ) - cohort_labels = sorted(ds["cohort_label"].values) - assert cohort_labels == expect_cohort_labels - - # check variants are consistent - assert ds.sizes["variants"] == len(df_af) - for v in expected_variant_vars: - c = v.split("variant_")[1] - actual = ds[v] - expect = df_af[c] - _compare_series_like(actual, expect) - - # check frequencies are consistent - for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): - print(cohort_label) - actual_frq = ds["event_frequency"].values[:, cohort_index] - expect_frq = df_af[f"frq_{cohort_label}"].values - assert_allclose(actual_frq, expect_frq) - - -# Here we don't explore the full matrix, but vary one parameter at a time, otherwise -# the test suite would take too long to run. 
- - -@pytest.mark.parametrize( - "transcript", ["LOC125767311_t2", "LOC125761549_t5", "LOC125761549_t7"] -) -def test_allele_frequencies_advanced__transcript(transcript): - _check_snp_allele_frequencies_advanced( - transcript=transcript, - ) - _check_aa_allele_frequencies_advanced( - transcript=transcript, - ) - - -@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"]) -def test_allele_frequencies_advanced__area_by(area_by): - _check_snp_allele_frequencies_advanced( - area_by=area_by, - ) - _check_aa_allele_frequencies_advanced( - area_by=area_by, - ) - - -@pytest.mark.parametrize("period_by", ["year", "quarter", "month"]) -def test_allele_frequencies_advanced__period_by(period_by): - _check_snp_allele_frequencies_advanced( - period_by=period_by, - ) - _check_aa_allele_frequencies_advanced( - period_by=period_by, - ) - - -@pytest.mark.parametrize( - "sample_sets", - [ - "1229-VO-GH-DADZIE-VMF00095", - ["1240-VO-CD-KOEKEMOER-VMF00099", "1240-VO-MZ-KOEKEMOER-VMF00101"], - "1.0", - ], -) -def test_allele_frequencies_advanced__sample_sets(sample_sets): - _check_snp_allele_frequencies_advanced( - sample_sets=sample_sets, - ) - _check_aa_allele_frequencies_advanced( - sample_sets=sample_sets, - ) - - -def test_allele_frequencies_advanced__sample_query(): - _check_snp_allele_frequencies_advanced( - sample_query="taxon == 'funestus' and country in ['Ghana', 'Gabon']", - ) - # noinspection PyTypeChecker - _check_aa_allele_frequencies_advanced( - sample_query="taxon == 'funestus' and country in ['Ghana', 'Gabon']", - variant_query=None, - ) - - -@pytest.mark.parametrize("min_cohort_size", [10, 40]) -def test_allele_frequencies_advanced__min_cohort_size(min_cohort_size): - _check_snp_allele_frequencies_advanced( - min_cohort_size=min_cohort_size, - ) - _check_aa_allele_frequencies_advanced( - min_cohort_size=min_cohort_size, - ) - - -@pytest.mark.parametrize( - "variant_query", - [ - None, - "effect == 'NON_SYNONYMOUS_CODING' and max_af > 0.05", - "effect == 'foobar'", # no variants - ], -) -def test_allele_frequencies_advanced__variant_query(variant_query): - _check_snp_allele_frequencies_advanced( - variant_query=variant_query, - ) - _check_aa_allele_frequencies_advanced( - variant_query=variant_query, - ) - - -@pytest.mark.parametrize("nobs_mode", ["called", "fixed"]) -def test_allele_frequencies_advanced__nobs_mode(nobs_mode): - _check_snp_allele_frequencies_advanced( - nobs_mode=nobs_mode, - ) - _check_aa_allele_frequencies_advanced( - nobs_mode=nobs_mode, - ) - - -# TODO: this function is a verbatim duplicate, from test_ag3.py -def _compare_series_like(actual, expect): - # compare pandas series-like objects for equality or floating point - # similarity, handling missing values appropriately - - # handle object arrays, these don't get nans compared properly - t = actual.dtype - if t == object: - expect = expect.fillna("NA") - actual = actual.fillna("NA") - - if t.kind == "f": - assert_allclose(actual.values, expect.values) - else: - assert_array_equal(actual.values, expect.values) - - def test_h12_gwss(): af1 = setup_af1(cohorts_analysis="20230823") sample_query = "country == 'Ghana'" diff --git a/tests/test_ag3.py b/tests/test_ag3.py index c20c5b617..feee66add 100644 --- a/tests/test_ag3.py +++ b/tests/test_ag3.py @@ -10,7 +10,7 @@ from malariagen_data import Ag3, Region from malariagen_data.anopheles import _cn_mode -from malariagen_data.util import locate_region, resolve_region +from malariagen_data.util import locate_region, resolve_region, compare_series_like contigs 
= "2R", "2L", "3R", "3L", "X" @@ -54,238 +54,6 @@ def test_cross_metadata(): assert df_crosses["sex"].unique().tolist() == expected_sex_values -def test_snp_effects(): - ag3 = setup_ag3() - gste2 = "AGAP009194-RA" - site_mask = "gamb_colu" - expected_fields = [ - "contig", - "position", - "ref_allele", - "alt_allele", - "pass_gamb_colu_arab", - "pass_gamb_colu", - "pass_arab", - "transcript", - "effect", - "impact", - "ref_codon", - "alt_codon", - "aa_pos", - "ref_aa", - "alt_aa", - "aa_change", - ] - - df = ag3.snp_effects(transcript=gste2, site_mask=site_mask) - assert isinstance(df, pd.DataFrame) - assert df.columns.tolist() == expected_fields - - # reverse strand gene - assert df.shape == (2838, len(expected_fields)) - # check first, second, third codon position non-syn - assert df.iloc[1454].aa_change == "I114L" - assert df.iloc[1446].aa_change == "I114M" - # while we are here, check all columns for a position - assert df.iloc[1451].position == 28598166 - assert df.iloc[1451].ref_allele == "A" - assert df.iloc[1451].alt_allele == "G" - assert df.iloc[1451].effect == "NON_SYNONYMOUS_CODING" - assert df.iloc[1451].impact == "MODERATE" - assert df.iloc[1451].ref_codon == "aTt" - assert df.iloc[1451].alt_codon == "aCt" - assert df.iloc[1451].aa_pos == 114 - assert df.iloc[1451].ref_aa == "I" - assert df.iloc[1451].alt_aa == "T" - assert df.iloc[1451].aa_change == "I114T" - # check syn - assert df.iloc[1447].aa_change == "I114I" - # check intronic - assert df.iloc[1197].effect == "INTRONIC" - # check 5' utr - assert df.iloc[2661].effect == "FIVE_PRIME_UTR" - # check 3' utr - assert df.iloc[0].effect == "THREE_PRIME_UTR" - - # test forward strand gene gste6 - gste6 = "AGAP009196-RA" - df = ag3.snp_effects(transcript=gste6, site_mask=site_mask) - assert isinstance(df, pd.DataFrame) - assert df.columns.tolist() == expected_fields - assert df.shape == (2829, len(expected_fields)) - - # check first, second, third codon position non-syn - assert df.iloc[701].aa_change == "E35*" - assert df.iloc[703].aa_change == "E35V" - # while we are here, check all columns for a position - assert df.iloc[706].position == 28600605 - assert df.iloc[706].ref_allele == "G" - assert df.iloc[706].alt_allele == "C" - assert df.iloc[706].effect == "NON_SYNONYMOUS_CODING" - assert df.iloc[706].impact == "MODERATE" - assert df.iloc[706].ref_codon == "gaG" - assert df.iloc[706].alt_codon == "gaC" - assert df.iloc[706].aa_pos == 35 - assert df.iloc[706].ref_aa == "E" - assert df.iloc[706].alt_aa == "D" - assert df.iloc[706].aa_change == "E35D" - # check syn - assert df.iloc[705].aa_change == "E35E" - # check intronic - assert df.iloc[900].effect == "INTRONIC" - # check 5' utr - assert df.iloc[0].effect == "FIVE_PRIME_UTR" - # check 3' utr - assert df.iloc[2828].effect == "THREE_PRIME_UTR" - - # check 5' utr intron and the different intron effects - utr_intron5 = "AGAP004679-RB" - df = ag3.snp_effects(transcript=utr_intron5, site_mask=site_mask) - assert isinstance(df, pd.DataFrame) - assert df.columns.tolist() == expected_fields - assert df.shape == (7686, len(expected_fields)) - assert df.iloc[180].effect == "SPLICE_CORE" - assert df.iloc[198].effect == "SPLICE_REGION" - assert df.iloc[202].effect == "INTRONIC" - - # check 3' utr intron - utr_intron3 = "AGAP000689-RA" - df = ag3.snp_effects(transcript=utr_intron3, site_mask=site_mask) - assert isinstance(df, pd.DataFrame) - assert df.columns.tolist() == expected_fields - assert df.shape == (5397, len(expected_fields)) - assert df.iloc[646].effect == "SPLICE_CORE" - 
assert df.iloc[652].effect == "SPLICE_REGION" - assert df.iloc[674].effect == "INTRONIC" - - -def test_snp_allele_frequencies__dict_cohorts(): - ag3 = setup_ag3(cohorts_analysis="20230516") - cohorts = { - "ke": "country == 'Kenya'", - "bf_2012_col": "country == 'Burkina Faso' and year == 2012 and aim_species == 'coluzzii'", - } - universal_fields = [ - "pass_gamb_colu_arab", - "pass_gamb_colu", - "pass_arab", - "label", - ] - - # test drop invariants - df = ag3.snp_allele_frequencies( - transcript="AGAP009194-RA", - cohorts=cohorts, - site_mask="gamb_colu", - sample_sets="3.0", - drop_invariant=True, - effects=False, - ) - - assert isinstance(df, pd.DataFrame) - frq_columns = ["frq_" + s for s in list(cohorts.keys())] - expected_fields = universal_fields + frq_columns + ["max_af"] - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.shape == (133, len(expected_fields)) - assert df.iloc[3].frq_ke == 0 - assert df.iloc[4].frq_bf_2012_col == pytest.approx(0.006097, abs=1e-6) - assert df.iloc[4].max_af == pytest.approx(0.006097, abs=1e-6) - # check invariant have been dropped - assert df.max_af.min() > 0 - - # test keep invariants - df = ag3.snp_allele_frequencies( - transcript="AGAP004707-RD", - cohorts=cohorts, - site_mask="gamb_colu", - sample_sets="3.0", - drop_invariant=False, - effects=False, - ) - assert isinstance(df, pd.DataFrame) - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.shape == (132306, len(expected_fields)) - # check invariant positions are still present - assert np.any(df.max_af == 0) - - -def test_snp_allele_frequencies__str_cohorts__effects(): - ag3 = setup_ag3(cohorts_analysis="20230516") - cohorts = "admin1_month" - min_cohort_size = 10 - universal_fields = [ - "pass_gamb_colu_arab", - "pass_gamb_colu", - "pass_arab", - "label", - ] - effects_fields = [ - "transcript", - "effect", - "impact", - "ref_codon", - "alt_codon", - "aa_pos", - "ref_aa", - "alt_aa", - ] - df = ag3.snp_allele_frequencies( - transcript="AGAP004707-RD", - cohorts=cohorts, - min_cohort_size=min_cohort_size, - site_mask="gamb_colu", - sample_sets="3.0", - drop_invariant=True, - effects=True, - ) - df_coh = ag3.cohorts_metadata(sample_sets="3.0") - coh_nm = "cohort_" + cohorts - coh_counts = df_coh[coh_nm].dropna().value_counts() - cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list() - frq_cohort_labels = ["frq_" + s for s in cohort_labels] - expected_fields = universal_fields + frq_cohort_labels + ["max_af"] + effects_fields - - assert isinstance(df, pd.DataFrame) - assert len(df) == 16641 - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.index.names == [ - "contig", - "position", - "ref_allele", - "alt_allele", - "aa_change", - ] - - -def test_snp_allele_frequencies__query(): - ag3 = setup_ag3(cohorts_analysis="20230516") - cohorts = "admin1_year" - min_cohort_size = 10 - expected_columns = [ - "pass_gamb_colu_arab", - "pass_gamb_colu", - "pass_arab", - "frq_AO-LUA_colu_2009", - "max_af", - "label", - ] - - df = ag3.snp_allele_frequencies( - transcript="AGAP004707-RD", - cohorts=cohorts, - sample_query="country == 'Angola'", - min_cohort_size=min_cohort_size, - site_mask="gamb_colu", - sample_sets="3.0", - drop_invariant=True, - effects=False, - ) - - assert isinstance(df, pd.DataFrame) - assert sorted(df.columns) == sorted(expected_columns) - assert len(df) == 695 - - @pytest.mark.parametrize("rows", [10, 100, 1000]) @pytest.mark.parametrize("cols", [10, 100, 1000]) @pytest.mark.parametrize("vmax", 
[2, 12, 100]) @@ -716,450 +484,6 @@ def test_locate_region(region_raw): assert region == Region("2R", 24630355, 24633221) -def test_aa_allele_frequencies(): - ag3 = setup_ag3(cohorts_analysis="20230516") - - expected_fields = [ - "transcript", - "aa_pos", - "ref_allele", - "alt_allele", - "ref_aa", - "alt_aa", - "effect", - "impact", - "frq_BF-09_gamb_2012", - "frq_BF-09_colu_2012", - "frq_BF-09_colu_2014", - "frq_BF-09_gamb_2014", - "frq_BF-07_gamb_2004", - "max_af", - "label", - ] - - df = ag3.aa_allele_frequencies( - transcript="AGAP004707-RD", - cohorts="admin1_year", - min_cohort_size=10, - site_mask="gamb_colu", - sample_sets=("AG1000G-BF-A", "AG1000G-BF-B", "AG1000G-BF-C"), - drop_invariant=True, - ) - - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert isinstance(df, pd.DataFrame) - assert df.index.names == ["aa_change", "contig", "position"] - assert df.shape == (61, len(expected_fields)) - assert df.loc["V402L"].max_af[0] == pytest.approx(0.121951, abs=1e-6) - - -# noinspection PyDefaultArgument -def _check_snp_allele_frequencies_advanced( - transcript="AGAP004707-RD", - area_by="admin1_iso", - period_by="year", - sample_sets=["AG1000G-BF-A", "AG1000G-ML-A", "AG1000G-UG"], - sample_query=None, - min_cohort_size=10, - nobs_mode="called", - variant_query="max_af > 0.02", -): - ag3 = setup_ag3(cohorts_analysis="20230516") - - ds = ag3.snp_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - nobs_mode=nobs_mode, - variant_query=variant_query, - ) - - assert isinstance(ds, xr.Dataset) - - # noinspection PyTypeChecker - assert sorted(ds.dims) == ["cohorts", "variants"] - - expected_variant_vars = ( - "variant_label", - "variant_contig", - "variant_position", - "variant_ref_allele", - "variant_alt_allele", - "variant_max_af", - "variant_pass_gamb_colu_arab", - "variant_pass_gamb_colu", - "variant_pass_arab", - "variant_transcript", - "variant_effect", - "variant_impact", - "variant_ref_codon", - "variant_alt_codon", - "variant_ref_aa", - "variant_alt_aa", - "variant_aa_pos", - "variant_aa_change", - ) - for v in expected_variant_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants",) - - expected_cohort_vars = ( - "cohort_label", - "cohort_size", - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_period_start", - "cohort_period_end", - "cohort_lat_mean", - "cohort_lat_min", - "cohort_lat_max", - "cohort_lon_mean", - "cohort_lon_min", - "cohort_lon_max", - ) - for v in expected_cohort_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("cohorts",) - - expected_event_vars = ( - "event_count", - "event_nobs", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ) - for v in expected_event_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants", "cohorts") - - # sanity checks for area values - df_samples = ag3.sample_metadata(sample_sets=sample_sets) - if sample_query is not None: - df_samples = df_samples.query(sample_query) - expected_area = np.unique(df_samples[area_by].dropna().values) - area = ds["cohort_area"].values - # N.B., some areas may not end up in final dataset if cohort - # size is too small, so do a set membership test - for a in area: - assert a in expected_area - - # sanity checks for period values - period = ds["cohort_period"].values - if period_by == "year": - expected_freqstr = "A-DEC" - elif 
period_by == "month": - expected_freqstr = "M" - elif period_by == "quarter": - expected_freqstr = "Q-DEC" - else: - assert False, "not implemented" - for p in period: - assert isinstance(p, pd.Period) - assert p.freqstr == expected_freqstr - - # sanity check cohort size - size = ds["cohort_size"].values - for s in size: - assert s >= min_cohort_size - - if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": - # Here we test the behaviour of the function when grouping by admin level - # 1 and year. We can do some more in-depth testing in this case because - # we can compare results directly against the simpler snp_allele_frequencies() - # function with the admin1_year cohorts. - - # check consistency with the basic snp allele frequencies method - df_af = ag3.snp_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - ) - df_af = df_af.reset_index() # make sure all variables available to check - if variant_query is not None: - df_af = df_af.query(variant_query) - - # check cohorts are consistent - expect_cohort_labels = sorted( - [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] - ) - cohort_labels = sorted(ds["cohort_label"].values) - assert cohort_labels == expect_cohort_labels - - # check variants are consistent - assert ds.sizes["variants"] == len(df_af) - for v in expected_variant_vars: - c = v.split("variant_")[1] - actual = ds[v] - expect = df_af[c] - _compare_series_like(actual, expect) - - # check frequencies are consistent - for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): - actual_frq = ds["event_frequency"].values[:, cohort_index] - expect_frq = df_af[f"frq_{cohort_label}"].values - assert_allclose(actual_frq, expect_frq) - - -# noinspection PyDefaultArgument -def _check_aa_allele_frequencies_advanced( - transcript="AGAP004707-RD", - area_by="admin1_iso", - period_by="year", - sample_sets=["AG1000G-BF-A", "AG1000G-ML-A", "AG1000G-UG"], - sample_query=None, - min_cohort_size=10, - nobs_mode="called", - variant_query="max_af > 0.02", -): - ag3 = setup_ag3(cohorts_analysis="20230516") - - ds = ag3.aa_allele_frequencies_advanced( - transcript=transcript, - area_by=area_by, - period_by=period_by, - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - nobs_mode=nobs_mode, - variant_query=variant_query, - ) - - assert isinstance(ds, xr.Dataset) - - # noinspection PyTypeChecker - assert sorted(ds.dims) == ["cohorts", "variants"] - - expected_variant_vars = ( - "variant_label", - "variant_contig", - "variant_position", - "variant_max_af", - "variant_transcript", - "variant_effect", - "variant_impact", - "variant_ref_aa", - "variant_alt_aa", - "variant_aa_pos", - "variant_aa_change", - ) - for v in expected_variant_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("variants",) - - expected_cohort_vars = ( - "cohort_label", - "cohort_size", - "cohort_taxon", - "cohort_area", - "cohort_period", - "cohort_period_start", - "cohort_period_end", - "cohort_lat_mean", - "cohort_lat_min", - "cohort_lat_max", - "cohort_lon_mean", - "cohort_lon_min", - "cohort_lon_max", - ) - for v in expected_cohort_vars: - a = ds[v] - assert isinstance(a, xr.DataArray) - assert a.dims == ("cohorts",) - - expected_event_vars = ( - "event_count", - "event_nobs", - "event_frequency", - "event_frequency_ci_low", - "event_frequency_ci_upp", - ) - for v in expected_event_vars: - a = ds[v] - 
assert isinstance(a, xr.DataArray) - assert a.dims == ("variants", "cohorts") - - # sanity checks for area values - df_samples = ag3.sample_metadata(sample_sets=sample_sets) - if sample_query is not None: - df_samples = df_samples.query(sample_query) - expected_area = np.unique(df_samples[area_by].dropna().values) - area = ds["cohort_area"].values - # N.B., some areas may not end up in final dataset if cohort - # size is too small, so do a set membership test - for a in area: - assert a in expected_area - - # sanity checks for period values - period = ds["cohort_period"].values - if period_by == "year": - expected_freqstr = "A-DEC" - elif period_by == "month": - expected_freqstr = "M" - elif period_by == "quarter": - expected_freqstr = "Q-DEC" - else: - assert False, "not implemented" - for p in period: - assert isinstance(p, pd.Period) - assert p.freqstr == expected_freqstr - - # sanity check cohort size - size = ds["cohort_size"].values - for s in size: - assert s >= min_cohort_size - - if area_by == "admin1_iso" and period_by == "year" and nobs_mode == "called": - # Here we test the behaviour of the function when grouping by admin level - # 1 and year. We can do some more in-depth testing in this case because - # we can compare results directly against the simpler aa_allele_frequencies() - # function with the admin1_year cohorts. - - # check consistency with the basic snp allele frequencies method - df_af = ag3.aa_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=sample_sets, - sample_query=sample_query, - min_cohort_size=min_cohort_size, - ) - df_af = df_af.reset_index() # make sure all variables available to check - if variant_query is not None: - df_af = df_af.query(variant_query) - - # check cohorts are consistent - expect_cohort_labels = sorted( - [c.split("frq_")[1] for c in df_af.columns if c.startswith("frq_")] - ) - cohort_labels = sorted(ds["cohort_label"].values) - assert cohort_labels == expect_cohort_labels - - # check variants are consistent - assert ds.sizes["variants"] == len(df_af) - for v in expected_variant_vars: - c = v.split("variant_")[1] - actual = ds[v] - expect = df_af[c] - _compare_series_like(actual, expect) - - # check frequencies are consistent - for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): - print(cohort_label) - actual_frq = ds["event_frequency"].values[:, cohort_index] - expect_frq = df_af[f"frq_{cohort_label}"].values - assert_allclose(actual_frq, expect_frq) - - -# Here we don't explore the full matrix, but vary one parameter at a time, otherwise -# the test suite would take too long to run. 
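
Both test modules previously carried a verbatim copy of `_compare_series_like` (removed above from tests/test_af1.py and, further below, from tests/test_ag3.py), and the remaining call sites are switched to `compare_series_like` imported from `malariagen_data.util`. The util implementation itself is not part of this patch, so the following is only a minimal sketch, assuming the shared helper simply mirrors the duplicated test helpers being deleted:

from numpy.testing import assert_allclose, assert_array_equal


def compare_series_like(actual, expect):
    # Compare pandas series-like objects for equality or floating point
    # similarity, handling missing values appropriately.

    # Object arrays: NaN values don't compare equal, so fill them first.
    t = actual.dtype
    if t == object:
        expect = expect.fillna("NA")
        actual = actual.fillna("NA")

    if t.kind == "f":
        assert_allclose(actual.values, expect.values)
    else:
        assert_array_equal(actual.values, expect.values)
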
- - -@pytest.mark.parametrize("transcript", ["AGAP004707-RD", "AGAP006028-RA"]) -def test_allele_frequencies_advanced__transcript(transcript): - _check_snp_allele_frequencies_advanced( - transcript=transcript, - ) - _check_aa_allele_frequencies_advanced( - transcript=transcript, - ) - - -@pytest.mark.parametrize("area_by", ["country", "admin1_iso", "admin2_name"]) -def test_allele_frequencies_advanced__area_by(area_by): - _check_snp_allele_frequencies_advanced( - area_by=area_by, - ) - _check_aa_allele_frequencies_advanced( - area_by=area_by, - ) - - -@pytest.mark.parametrize("period_by", ["year", "quarter", "month"]) -def test_allele_frequencies_advanced__period_by(period_by): - _check_snp_allele_frequencies_advanced( - period_by=period_by, - ) - _check_aa_allele_frequencies_advanced( - period_by=period_by, - ) - - -@pytest.mark.parametrize( - "sample_sets", ["AG1000G-BF-A", ["AG1000G-BF-A", "AG1000G-ML-A"], "3.0"] -) -def test_allele_frequencies_advanced__sample_sets(sample_sets): - _check_snp_allele_frequencies_advanced( - sample_sets=sample_sets, - ) - _check_aa_allele_frequencies_advanced( - sample_sets=sample_sets, - ) - - -@pytest.mark.parametrize( - "sample_query", - [ - "taxon in ['gambiae', 'coluzzii'] and country == 'Mali'", - "taxon == 'arabiensis' and country in ['Uganda', 'Tanzania']", - ], -) -def test_allele_frequencies_advanced__sample_query(sample_query): - _check_snp_allele_frequencies_advanced( - sample_query=sample_query, - ) - # noinspection PyTypeChecker - _check_aa_allele_frequencies_advanced( - sample_query=sample_query, - variant_query=None, - ) - - -@pytest.mark.parametrize("min_cohort_size", [10, 100]) -def test_allele_frequencies_advanced__min_cohort_size(min_cohort_size): - _check_snp_allele_frequencies_advanced( - min_cohort_size=min_cohort_size, - ) - _check_aa_allele_frequencies_advanced( - min_cohort_size=min_cohort_size, - ) - - -@pytest.mark.parametrize( - "variant_query", - [ - None, - "effect == 'NON_SYNONYMOUS_CODING' and max_af > 0.05", - "effect == 'foobar'", # no variants - ], -) -def test_allele_frequencies_advanced__variant_query(variant_query): - _check_snp_allele_frequencies_advanced( - variant_query=variant_query, - ) - _check_aa_allele_frequencies_advanced( - variant_query=variant_query, - ) - - -@pytest.mark.parametrize("nobs_mode", ["called", "fixed"]) -def test_allele_frequencies_advanced__nobs_mode(nobs_mode): - _check_snp_allele_frequencies_advanced( - nobs_mode=nobs_mode, - ) - _check_aa_allele_frequencies_advanced( - nobs_mode=nobs_mode, - ) - - # noinspection PyDefaultArgument def _check_gene_cnv_frequencies_advanced( region="2L", @@ -1303,7 +627,7 @@ def _check_gene_cnv_frequencies_advanced( c = v.split("variant_")[1] actual = ds[v] expect = df_af[c] - _compare_series_like(actual, expect) + compare_series_like(actual, expect) # check frequencies are consistent for cohort_index, cohort_label in enumerate(ds["cohort_label"].values): @@ -1435,7 +759,7 @@ def test_gene_cnv_frequencies_advanced__multi_contig_x(): for v in ds1: a = ds1[v] b = ds2[v] - _compare_series_like(a, b) + compare_series_like(a, b) def test_gene_cnv_frequencies_advanced__missing_samples(): @@ -1469,22 +793,6 @@ def test_gene_cnv_frequencies_advanced__dup_samples(): assert ds.dims == ds_dup.dims -def _compare_series_like(actual, expect): - # compare pandas series-like objects for equality or floating point - # similarity, handling missing values appropriately - - # handle object arrays, these don't get nans compared properly - t = actual.dtype - if t == object: - 
expect = expect.fillna("NA") - actual = actual.fillna("NA") - - if t.kind == "f": - assert_allclose(actual.values, expect.values) - else: - assert_array_equal(actual.values, expect.values) - - def test_h12_gwss(): ag3 = setup_ag3(cohorts_analysis="20230516") sample_query = "country == 'Ghana'" diff --git a/tests/test_anopheles.py b/tests/test_anopheles.py index 9b2ca7bd1..1046d0682 100644 --- a/tests/test_anopheles.py +++ b/tests/test_anopheles.py @@ -1,8 +1,6 @@ import numpy as np -import pandas as pd import pytest from numpy.testing import assert_allclose -from pandas.testing import assert_frame_equal from malariagen_data import Af1, Ag3 from malariagen_data.af1 import GCS_URL as AF1_GCS_URL @@ -43,243 +41,6 @@ def setup_subclass_cached(subclass, **kwargs): return setup_subclass(subclass, url=url, **kwargs) -@pytest.mark.parametrize( - "subclass, sample_sets, universal_fields, transcript, site_mask, cohorts_analysis, expected_snp_count", - [ - ( - Ag3, - "3.0", - [ - "pass_gamb_colu_arab", - "pass_gamb_colu", - "pass_arab", - "label", - ], - "AGAP004707-RD", - "gamb_colu", - "20211101", - 16526, - ), - ( - Af1, - "1.0", - [ - "pass_funestus", - "label", - ], - "LOC125767311_t2", - "funestus", - "20221129", - 4221, - ), - ], -) -def test_snp_allele_frequencies__str_cohorts( - subclass, - sample_sets, - universal_fields, - transcript, - site_mask, - cohorts_analysis, - expected_snp_count, -): - anoph = setup_subclass_cached(subclass, cohorts_analysis=cohorts_analysis) - - cohorts = "admin1_month" - min_cohort_size = 10 - df = anoph.snp_allele_frequencies( - transcript=transcript, - cohorts=cohorts, - min_cohort_size=min_cohort_size, - site_mask=site_mask, - sample_sets=sample_sets, - drop_invariant=True, - effects=False, - ) - df_coh = anoph.cohorts_metadata(sample_sets=sample_sets) - coh_nm = "cohort_" + cohorts - coh_counts = df_coh[coh_nm].dropna().value_counts() - cohort_labels = coh_counts[coh_counts >= min_cohort_size].index.to_list() - frq_cohort_labels = ["frq_" + s for s in cohort_labels] - expected_fields = universal_fields + frq_cohort_labels + ["max_af"] - - assert isinstance(df, pd.DataFrame) - assert sorted(df.columns.tolist()) == sorted(expected_fields) - assert df.index.names == ["contig", "position", "ref_allele", "alt_allele"] - assert len(df) == expected_snp_count - - -@pytest.mark.parametrize( - "subclass, transcript, sample_set", - [ - ( - Ag3, - "AGAP004707-RD", - "AG1000G-FR", - ), - ( - Af1, - "LOC125767311_t2", - "1229-VO-GH-DADZIE-VMF00095", - ), - ], -) -def test_snp_allele_frequencies__dup_samples( - subclass, - transcript, - sample_set, -): - # Expect automatically deduplicate any sample sets. 
- anoph = setup_subclass_cached(subclass) - df = anoph.snp_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=[sample_set], - ) - df_dup = anoph.snp_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=[sample_set, sample_set], - ) - assert_frame_equal(df, df_dup) - - -@pytest.mark.parametrize( - "subclass, transcript, sample_sets", - [ - ( - Ag3, - "foobar", - "3.0", - ), - ( - Af1, - "foobar", - "1.0", - ), - ], -) -def test_snp_allele_frequencies__bad_transcript( - subclass, - transcript, - sample_sets, -): - anoph = setup_subclass_cached(subclass) - with pytest.raises(ValueError): - anoph.snp_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=sample_sets, - ) - - -@pytest.mark.parametrize( - "subclass, cohorts_analysis, transcript, sample_set", - [ - ( - Ag3, - "20211101", - "AGAP004707-RD", - "AG1000G-FR", - ), - ( - Af1, - "20221129", - "LOC125767311_t2", - "1229-VO-GH-DADZIE-VMF00095", - ), - ], -) -def test_aa_allele_frequencies__dup_samples( - subclass, cohorts_analysis, transcript, sample_set -): - # Expect automatically deduplicate sample sets. - anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis) - df = anoph.aa_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=[sample_set], - ) - df_dup = anoph.aa_allele_frequencies( - transcript=transcript, - cohorts="admin1_year", - sample_sets=[sample_set, sample_set], - ) - assert_frame_equal(df, df_dup) - - -@pytest.mark.parametrize( - "subclass, cohorts_analysis, transcript, sample_set", - [ - ( - Ag3, - "20211101", - "AGAP004707-RD", - "AG1000G-FR", - ), - ( - Af1, - "20221129", - "LOC125767311_t2", - "1229-VO-GH-DADZIE-VMF00095", - ), - ], -) -def test_snp_allele_frequencies_advanced__dup_samples( - subclass, cohorts_analysis, transcript, sample_set -): - anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis) - ds = anoph.snp_allele_frequencies_advanced( - transcript=transcript, - area_by="admin1_iso", - period_by="year", - sample_sets=[sample_set], - ) - ds_dup = anoph.snp_allele_frequencies_advanced( - transcript=transcript, - area_by="admin1_iso", - period_by="year", - sample_sets=[sample_set, sample_set], - ) - assert ds.dims == ds_dup.dims - - -@pytest.mark.parametrize( - "subclass, cohorts_analysis, transcript, sample_set", - [ - ( - Ag3, - "20211101", - "AGAP004707-RD", - "AG1000G-FR", - ), - ( - Af1, - "20221129", - "LOC125767311_t2", - "1229-VO-GH-DADZIE-VMF00095", - ), - ], -) -def test_aa_allele_frequencies_advanced__dup_samples( - subclass, cohorts_analysis, transcript, sample_set -): - anoph = setup_subclass_cached(subclass=subclass, cohorts_analysis=cohorts_analysis) - ds_dup = anoph.aa_allele_frequencies_advanced( - transcript=transcript, - area_by="admin1_iso", - period_by="year", - sample_sets=[sample_set, sample_set], - ) - ds = anoph.aa_allele_frequencies_advanced( - transcript=transcript, - area_by="admin1_iso", - period_by="year", - sample_sets=[sample_set], - ) - assert ds.dims == ds_dup.dims - - def test_haplotype_frequencies(): h1 = np.array( [