From fac34695811a065275617a3390e960159f718450 Mon Sep 17 00:00:00 2001 From: Sanjay Curtis Nagi <34922269+sanjaynagi@users.noreply.github.com> Date: Wed, 20 Mar 2024 18:54:57 +0000 Subject: [PATCH] Implement diplotype clustering, refactor haplotype clustering, and fix random_region_str() (#507) * add diplotype clustering * pre commit fixes * fix remove self from args * fix not being imported * change capitalisation of DipClust class * fix * fixes2 * fix nb * add tests * precommit lint * test fix * fix * test fix * fix * fix * fix random_region_str to prevent simulated regions exceeding contig size * add sample_query to testW * remove assert distance metric line * fix * refactor hap clust * fix dip clust y axes + axes title * slighly more buffer dipclust * improve codecov * precommit * alimanfoo suggestions * small edits * fix lint --- malariagen_data/anoph/dipclust.py | 275 +++++++++++++ malariagen_data/anoph/dipclust_params.py | 24 ++ malariagen_data/anoph/hapclust.py | 271 +++++++++++++ malariagen_data/anopheles.py | 252 +----------- malariagen_data/plotly_dendrogram.py | 6 +- malariagen_data/util.py | 135 +++++++ notebooks/plot_diplotype_clustering.ipynb | 449 ++++++++++++++++++++++ tests/anoph/conftest.py | 4 + tests/anoph/test_dipclust.py | 102 +++++ tests/anoph/test_hapclust.py | 98 +++++ 10 files changed, 1366 insertions(+), 250 deletions(-) create mode 100644 malariagen_data/anoph/dipclust.py create mode 100644 malariagen_data/anoph/dipclust_params.py create mode 100644 malariagen_data/anoph/hapclust.py create mode 100644 notebooks/plot_diplotype_clustering.ipynb create mode 100644 tests/anoph/test_dipclust.py create mode 100644 tests/anoph/test_hapclust.py diff --git a/malariagen_data/anoph/dipclust.py b/malariagen_data/anoph/dipclust.py new file mode 100644 index 000000000..a8173958c --- /dev/null +++ b/malariagen_data/anoph/dipclust.py @@ -0,0 +1,275 @@ +from typing import Optional, Tuple + +import allel # type: ignore +import numpy as np +from numpydoc_decorator import doc # type: ignore + +from ..util import ( + CacheMiss, + check_types, + multiallelic_diplotype_pdist, + multiallelic_diplotype_mean_sqeuclidean, + multiallelic_diplotype_mean_cityblock, +) +from ..plotly_dendrogram import plot_dendrogram +from . import base_params, plotly_params, tree_params, dipclust_params +from .base_params import DEFAULT +from .snp_data import AnophelesSnpData + + +class AnophelesDipClustAnalysis( + AnophelesSnpData, +): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + @check_types + @doc( + summary="Hierarchically cluster diplotypes in region and produce an interactive plot.", + parameters=dict( + leaf_y="Y coordinate at which to plot the leaf markers.", + ), + ) + def plot_diplotype_clustering( + self, + region: base_params.regions, + site_mask: base_params.site_mask = DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + cohort_size: Optional[base_params.cohort_size] = None, + random_seed: base_params.random_seed = 42, + color: plotly_params.color = None, + symbol: plotly_params.symbol = None, + linkage_method: dipclust_params.linkage_method = dipclust_params.linkage_method_default, + distance_metric: dipclust_params.distance_metric = dipclust_params.distance_metric_default, + count_sort: Optional[tree_params.count_sort] = None, + distance_sort: Optional[tree_params.distance_sort] = None, + title: plotly_params.title = True, + title_font_size: plotly_params.title_font_size = 14, + width: plotly_params.width = None, + height: plotly_params.height = 500, + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + render_mode: plotly_params.render_mode = "svg", + leaf_y: int = 0, + marker_size: plotly_params.marker_size = 5, + line_width: plotly_params.line_width = 0.5, + line_color: plotly_params.line_color = "black", + color_discrete_sequence: plotly_params.color_discrete_sequence = None, + color_discrete_map: plotly_params.color_discrete_map = None, + category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", + ) -> plotly_params.figure: + import sys + + debug = self._log.debug + + # Normalise params. + if count_sort is None and distance_sort is None: + count_sort = True + distance_sort = False + + # This is needed to avoid RecursionError on some haplotype clustering analyses + # with larger numbers of haplotypes. + sys.setrecursionlimit(10_000) + + debug("load sample metadata") + df_samples = self.sample_metadata( + sample_sets=sample_sets, sample_query=sample_query + ) + + dist, gt_samples, n_snps_used = self.diplotype_pairwise_distances( + region=region, + site_mask=site_mask, + sample_sets=sample_sets, + sample_query=sample_query, + cohort_size=cohort_size, + distance_metric=distance_metric, + random_seed=random_seed, + ) + + # Align sample metadata with genotypes. + df_samples = ( + df_samples.set_index("sample_id").loc[gt_samples.tolist()].reset_index() + ) + + # Normalise color and symbol parameters. + symbol_prepped = self._setup_sample_symbol( + data=df_samples, + symbol=symbol, + ) + del symbol + ( + color_prepped, + color_discrete_map_prepped, + category_orders_prepped, + ) = self._setup_sample_colors_plotly( + data=df_samples, + color=color, + color_discrete_map=color_discrete_map, + color_discrete_sequence=color_discrete_sequence, + category_orders=category_orders, + ) + del color + del color_discrete_map + del color_discrete_sequence + + # Configure hover data. + hover_data = self._setup_sample_hover_data_plotly( + color=color_prepped, symbol=symbol_prepped + ) + + # Construct plot title. + if title is True: + title_lines = [] + if sample_sets is not None: + title_lines.append(f"Sample sets: {sample_sets}") + if sample_query is not None: + title_lines.append(f"Sample query: {sample_query}") + title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)") + title = "
".join(title_lines) + + # Create the plot. + with self._spinner("Plot dendrogram"): + fig = plot_dendrogram( + dist=dist, + linkage_method=linkage_method, + count_sort=count_sort, + distance_sort=distance_sort, + render_mode=render_mode, + width=width, + height=height, + title=title, + line_width=line_width, + line_color=line_color, + marker_size=marker_size, + leaf_data=df_samples, + leaf_hover_name="sample_id", + leaf_hover_data=hover_data, + leaf_color=color_prepped, + leaf_symbol=symbol_prepped, + leaf_y=leaf_y, + leaf_color_discrete_map=color_discrete_map_prepped, + leaf_category_orders=category_orders_prepped, + template="simple_white", + y_axis_title=f"Distance ({distance_metric})", + y_axis_buffer=0.1, + ) + + # Tidy up. + fig.update_layout( + title_font=dict( + size=title_font_size, + ), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), + ) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + def diplotype_pairwise_distances( + self, + region: base_params.regions, + site_mask: base_params.site_mask = DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + site_class: Optional[base_params.site_class] = None, + cohort_size: Optional[base_params.cohort_size] = None, + distance_metric: dipclust_params.distance_metric = dipclust_params.distance_metric_default, + random_seed: base_params.random_seed = 42, + ) -> Tuple[np.ndarray, np.ndarray, int]: + # Change this name if you ever change the behaviour of this function, to + # invalidate any previously cached data. + name = "diplotype_pairwise_distances_v1" + + # Normalize params for consistent hash value. + sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets) + region_prepped = self._prep_region_cache_param(region=region) + params = dict( + region=region_prepped, + site_mask=site_mask, + sample_sets=sample_sets_prepped, + sample_query=sample_query, + site_class=site_class, + cohort_size=cohort_size, + distance_metric=distance_metric, + random_seed=random_seed, + ) + + # Try to retrieve results from the cache. + try: + results = self.results_cache_get(name=name, params=params) + + except CacheMiss: + results = self._diplotype_pairwise_distances(**params) + self.results_cache_set(name=name, params=params, results=results) + + # Unpack results") + dist: np.ndarray = results["dist"] + gt_samples: np.ndarray = results["gt_samples"] + n_snps: int = int(results["n_snps"][()]) # ensure scalar + + return dist, gt_samples, n_snps + + def _diplotype_pairwise_distances( + self, + *, + region, + site_mask, + sample_sets, + sample_query, + site_class, + cohort_size, + distance_metric, + random_seed, + ): + if distance_metric == "cityblock": + metric = multiallelic_diplotype_mean_cityblock + elif distance_metric == "euclidean": + metric = multiallelic_diplotype_mean_sqeuclidean + + # Load haplotypes. + ds_snps = self.snp_calls( + region=region, + sample_query=sample_query, + sample_sets=sample_sets, + site_mask=site_mask, + site_class=site_class, + cohort_size=cohort_size, + random_seed=random_seed, + ) + + with self._dask_progress(desc="Load genotypes"): + gt = ds_snps["call_genotype"].data.compute() + + with self._spinner( + desc="Compute allele counts and remove non-segregating sites" + ): + # Compute allele count, remove non-segregating sites. + ac = allel.GenotypeArray(gt).count_alleles(max_allele=3) + gt_seg = gt.compress(ac.is_segregating(), axis=0) + ac_seg = allel.GenotypeArray(gt_seg).to_allele_counts(max_allele=3) + X = np.ascontiguousarray(np.swapaxes(ac_seg.values, 0, 1)) + + # Compute pairwise distances. + with self._spinner(desc="Compute pairwise distances"): + dist = multiallelic_diplotype_pdist(X, metric=metric) + + # Extract IDs of samples. Convert to "U" dtype here + # to allow these to be saved to the results cache. + gt_samples = ds_snps["sample_id"].values.astype("U") + + return dict( + dist=dist, + gt_samples=gt_samples, + n_snps=np.array(gt_seg.shape[0]), + ) diff --git a/malariagen_data/anoph/dipclust_params.py b/malariagen_data/anoph/dipclust_params.py new file mode 100644 index 000000000..d55436b5d --- /dev/null +++ b/malariagen_data/anoph/dipclust_params.py @@ -0,0 +1,24 @@ +"""Parameters for diplotype clustering functions.""" + +from typing import Literal + +from typing_extensions import Annotated, TypeAlias + +linkage_method: TypeAlias = Annotated[ + Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"], + """ + The linkage algorithm to use. See the Linkage Methods section of the + scipy.cluster.hierarchy.linkage docs for full descriptions. + """, +] + +linkage_method_default: linkage_method = "complete" + +distance_metric: TypeAlias = Annotated[ + Literal["cityblock", "euclidean"], + """ + The distance metric to use. Either "cityblock" or "euclidean". + """, +] + +distance_metric_default: distance_metric = "cityblock" diff --git a/malariagen_data/anoph/hapclust.py b/malariagen_data/anoph/hapclust.py new file mode 100644 index 000000000..fa3bc2ea2 --- /dev/null +++ b/malariagen_data/anoph/hapclust.py @@ -0,0 +1,271 @@ +from typing import Optional, Tuple + +import allel # type: ignore +import numpy as np +import pandas as pd +from numpydoc_decorator import doc # type: ignore + +from ..util import CacheMiss, check_types, pdist_abs_hamming +from ..plotly_dendrogram import plot_dendrogram +from . import base_params, plotly_params, tree_params, hap_params, hapclust_params +from .base_params import DEFAULT +from .snp_data import AnophelesSnpData +from .hap_data import AnophelesHapData + + +class AnophelesHapClustAnalysis(AnophelesHapData, AnophelesSnpData): + def __init__( + self, + **kwargs, + ): + # N.B., this class is designed to work cooperatively, and + # so it's important that any remaining parameters are passed + # to the superclass constructor. + super().__init__(**kwargs) + + @check_types + @doc( + summary=""" + Hierarchically cluster haplotypes in region and produce an interactive plot. + """, + parameters=dict( + leaf_y="Y coordinate at which to plot the leaf markers.", + ), + ) + def plot_haplotype_clustering( + self, + region: base_params.regions, + analysis: hap_params.analysis = DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + cohort_size: Optional[base_params.cohort_size] = None, + random_seed: base_params.random_seed = 42, + color: plotly_params.color = None, + symbol: plotly_params.symbol = None, + linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default, + count_sort: Optional[tree_params.count_sort] = None, + distance_sort: Optional[tree_params.distance_sort] = None, + title: plotly_params.title = True, + title_font_size: plotly_params.title_font_size = 14, + width: plotly_params.width = None, + height: plotly_params.height = 500, + show: plotly_params.show = True, + renderer: plotly_params.renderer = None, + render_mode: plotly_params.render_mode = "svg", + leaf_y: int = 0, + marker_size: plotly_params.marker_size = 5, + line_width: plotly_params.line_width = 0.5, + line_color: plotly_params.line_color = "black", + color_discrete_sequence: plotly_params.color_discrete_sequence = None, + color_discrete_map: plotly_params.color_discrete_map = None, + category_orders: plotly_params.category_order = None, + legend_sizing: plotly_params.legend_sizing = "constant", + ) -> plotly_params.figure: + import sys + + # Normalise params. + if count_sort is None and distance_sort is None: + count_sort = True + distance_sort = False + + # This is needed to avoid RecursionError on some haplotype clustering analyses + # with larger numbers of haplotypes. + sys.setrecursionlimit(10_000) + + # Load sample metadata. + df_samples = self.sample_metadata( + sample_sets=sample_sets, sample_query=sample_query + ) + + # Compute pairwise distances. + dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances( + region=region, + analysis=analysis, + sample_sets=sample_sets, + sample_query=sample_query, + cohort_size=cohort_size, + random_seed=random_seed, + ) + + # Align sample metadata with haplotypes. + df_samples_phased = ( + df_samples.set_index("sample_id").loc[phased_samples.tolist()].reset_index() + ) + + # Normalise color and symbol parameters. + symbol_prepped = self._setup_sample_symbol( + data=df_samples_phased, + symbol=symbol, + ) + del symbol + ( + color_prepped, + color_discrete_map_prepped, + category_orders_prepped, + ) = self._setup_sample_colors_plotly( + data=df_samples_phased, + color=color, + color_discrete_map=color_discrete_map, + color_discrete_sequence=color_discrete_sequence, + category_orders=category_orders, + ) + del color + del color_discrete_map + del color_discrete_sequence + + # Repeat the dataframe so there is one row of metadata for each haplotype. + df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0)) + df_haps.columns = df_samples_phased.columns + + # Configure hover data. + hover_data = self._setup_sample_hover_data_plotly( + color=color_prepped, symbol=symbol_prepped + ) + + # Construct plot title. + if title is True: + title_lines = [] + if sample_sets is not None: + title_lines.append(f"Sample sets: {sample_sets}") + if sample_query is not None: + title_lines.append(f"Sample query: {sample_query}") + title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)") + title = "
".join(title_lines) + + # Create the plot. + with self._spinner("Plot dendrogram"): + fig = plot_dendrogram( + dist=dist, + linkage_method=linkage_method, + count_sort=count_sort, + distance_sort=distance_sort, + render_mode=render_mode, + width=width, + height=height, + title=title, + line_width=line_width, + line_color=line_color, + marker_size=marker_size, + leaf_data=df_haps, + leaf_hover_name="sample_id", + leaf_hover_data=hover_data, + leaf_color=color_prepped, + leaf_symbol=symbol_prepped, + leaf_y=leaf_y, + leaf_color_discrete_map=color_discrete_map_prepped, + leaf_category_orders=category_orders_prepped, + template="simple_white", + y_axis_title="Distance (no. SNPs)", + y_axis_buffer=1, + ) + + # Tidy up. + fig.update_layout( + title_font=dict( + size=title_font_size, + ), + legend=dict(itemsizing=legend_sizing, tracegroupgap=0), + ) + + if show: # pragma: no cover + fig.show(renderer=renderer) + return None + else: + return fig + + @doc( + summary=""" + Compute pairwise distances between haplotypes. + """, + returns=dict( + dist="Pairwise distance.", + phased_samples="Sample identifiers for haplotypes.", + n_snps="Number of SNPs used.", + ), + ) + def haplotype_pairwise_distances( + self, + region: base_params.regions, + analysis: hap_params.analysis = DEFAULT, + sample_sets: Optional[base_params.sample_sets] = None, + sample_query: Optional[base_params.sample_query] = None, + cohort_size: Optional[base_params.cohort_size] = None, + random_seed: base_params.random_seed = 42, + ) -> Tuple[np.ndarray, np.ndarray, int]: + # Change this name if you ever change the behaviour of this function, to + # invalidate any previously cached data. + name = "haplotype_pairwise_distances" + + # Normalize params for consistent hash value. + sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets) + region_prepped = self._prep_region_cache_param(region=region) + params = dict( + region=region_prepped, + analysis=analysis, + sample_sets=sample_sets_prepped, + sample_query=sample_query, + cohort_size=cohort_size, + random_seed=random_seed, + ) + + # Try to retrieve results from the cache. + try: + results = self.results_cache_get(name=name, params=params) + + except CacheMiss: + results = self._haplotype_pairwise_distances(**params) + self.results_cache_set(name=name, params=params, results=results) + + # Unpack results") + dist: np.ndarray = results["dist"] + phased_samples: np.ndarray = results["phased_samples"] + n_snps: int = int(results["n_snps"][()]) # ensure scalar + + return dist, phased_samples, n_snps + + def _haplotype_pairwise_distances( + self, + *, + region, + analysis, + sample_sets, + sample_query, + cohort_size, + random_seed, + ): + from scipy.spatial.distance import squareform # type: ignore + + # Load haplotypes. + ds_haps = self.haplotypes( + region=region, + analysis=analysis, + sample_query=sample_query, + sample_sets=sample_sets, + cohort_size=cohort_size, + random_seed=random_seed, + ) + gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data) + with self._dask_progress(desc="Load haplotypes"): + ht = gt.to_haplotypes().compute().values + + # Compute allele count, remove non-segregating sites. + ac = allel.HaplotypeArray(ht).count_alleles(max_allele=1) + ht_seg = ht[ac.is_segregating()] + + # Transpose memory layout for faster hamming distance calculations. + ht_t = np.ascontiguousarray(ht_seg.T) + + # Compute pairwise distances. + with self._spinner(desc="Compute pairwise distances"): + dist_sq = pdist_abs_hamming(ht_t) + dist = squareform(dist_sq) + + # Extract IDs of phased samples. Convert to "U" dtype here + # to allow these to be saved to the results cache. + phased_samples = ds_haps["sample_id"].values.astype("U") + + return dict( + dist=dist, + phased_samples=phased_samples, + n_snps=np.array(ht.shape[0]), + ) diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index 5fb8ddb70..1962031e8 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -32,7 +32,6 @@ dash_params, diplotype_distance_params, gplt_params, - hapclust_params, hapnet_params, het_params, ihs_params, @@ -55,6 +54,8 @@ from .anoph.h12 import AnophelesH12Analysis from .anoph.h1x import AnophelesH1XAnalysis from .mjn import median_joining_network, mjn_graph +from .anoph.hapclust import AnophelesHapClustAnalysis +from .anoph.dipclust import AnophelesDipClustAnalysis from .util import ( CacheMiss, Region, @@ -66,7 +67,6 @@ jackknife_ci, parse_multi_region, parse_single_region, - pdist_abs_hamming, plotly_discrete_legend, region_str, simple_xarray_concat, @@ -98,6 +98,8 @@ # work around pycharm failing to recognise that doc() is callable # noinspection PyCallingNonCallable class AnophelesDataResource( + AnophelesDipClustAnalysis, + AnophelesHapClustAnalysis, AnophelesH1XAnalysis, AnophelesH12Analysis, AnophelesG123Analysis, @@ -2632,252 +2634,6 @@ def plot_xpehh_gwss_track( else: return fig - @doc( - summary=""" - Compute pairwise distances between haplotypes. - """, - returns=dict( - dist="Pairwise distance.", - phased_samples="Sample identifiers for haplotypes.", - n_snps="Number of SNPs used.", - ), - ) - def haplotype_pairwise_distances( - self, - region: base_params.regions, - analysis: hap_params.analysis = DEFAULT, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - cohort_size: Optional[base_params.cohort_size] = None, - random_seed: base_params.random_seed = 42, - ) -> Tuple[np.ndarray, np.ndarray, int]: - # Change this name if you ever change the behaviour of this function, to - # invalidate any previously cached data. - name = "haplotype_pairwise_distances" - - # Normalize params for consistent hash value. - sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets) - region_prepped = self._prep_region_cache_param(region=region) - params = dict( - region=region_prepped, - analysis=analysis, - sample_sets=sample_sets_prepped, - sample_query=sample_query, - cohort_size=cohort_size, - random_seed=random_seed, - ) - - # Try to retrieve results from the cache. - try: - results = self.results_cache_get(name=name, params=params) - - except CacheMiss: - results = self._haplotype_pairwise_distances(**params) - self.results_cache_set(name=name, params=params, results=results) - - # Unpack results") - dist: np.ndarray = results["dist"] - phased_samples: np.ndarray = results["phased_samples"] - n_snps: int = int(results["n_snps"][()]) # ensure scalar - - return dist, phased_samples, n_snps - - def _haplotype_pairwise_distances( - self, - *, - region, - analysis, - sample_sets, - sample_query, - cohort_size, - random_seed, - ): - from scipy.spatial.distance import squareform # type: ignore - - # Load haplotypes. - ds_haps = self.haplotypes( - region=region, - analysis=analysis, - sample_query=sample_query, - sample_sets=sample_sets, - cohort_size=cohort_size, - random_seed=random_seed, - ) - gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data) - with self._dask_progress(desc="Load haplotypes"): - ht = gt.to_haplotypes().compute().values - - # Compute allele count, remove non-segregating sites. - ac = allel.HaplotypeArray(ht).count_alleles(max_allele=1) - ht_seg = ht[ac.is_segregating()] - - # Transpose memory layout for faster hamming distance calculations. - ht_t = np.ascontiguousarray(ht_seg.T) - - # Compute pairwise distances. - with self._spinner(desc="Compute pairwise distances"): - dist_sq = pdist_abs_hamming(ht_t) - dist = squareform(dist_sq) - - # Extract IDs of phased samples. Convert to "U" dtype here - # to allow these to be saved to the results cache. - phased_samples = ds_haps["sample_id"].values.astype("U") - - return dict( - dist=dist, - phased_samples=phased_samples, - n_snps=np.array(ht.shape[0]), - ) - - @doc( - summary=""" - Hierarchically cluster haplotypes in region and produce an interactive plot. - """, - parameters=dict( - leaf_y="Y coordinate at which to plot the leaf markers.", - ), - ) - def plot_haplotype_clustering( - self, - region: base_params.regions, - analysis: hap_params.analysis = DEFAULT, - sample_sets: Optional[base_params.sample_sets] = None, - sample_query: Optional[base_params.sample_query] = None, - cohort_size: Optional[base_params.cohort_size] = None, - random_seed: base_params.random_seed = 42, - color: plotly_params.color = None, - symbol: plotly_params.symbol = None, - linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default, - count_sort: Optional[tree_params.count_sort] = None, - distance_sort: Optional[tree_params.distance_sort] = None, - title: plotly_params.title = True, - title_font_size: plotly_params.title_font_size = 14, - width: plotly_params.width = None, - height: plotly_params.height = 500, - show: plotly_params.show = True, - renderer: plotly_params.renderer = None, - render_mode: plotly_params.render_mode = "svg", - leaf_y: int = 0, - marker_size: plotly_params.marker_size = 5, - line_width: plotly_params.line_width = 0.5, - line_color: plotly_params.line_color = "black", - color_discrete_sequence: plotly_params.color_discrete_sequence = None, - color_discrete_map: plotly_params.color_discrete_map = None, - category_orders: plotly_params.category_order = None, - legend_sizing: plotly_params.legend_sizing = "constant", - ) -> plotly_params.figure: - import sys - - from .plotly_dendrogram import plot_dendrogram - - # Normalise params. - if count_sort is None and distance_sort is None: - count_sort = True - distance_sort = False - - # This is needed to avoid RecursionError on some haplotype clustering analyses - # with larger numbers of haplotypes. - sys.setrecursionlimit(10_000) - - # Load sample metadata. - df_samples = self.sample_metadata( - sample_sets=sample_sets, sample_query=sample_query - ) - - # Compute pairwise distances. - dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances( - region=region, - analysis=analysis, - sample_sets=sample_sets, - sample_query=sample_query, - cohort_size=cohort_size, - random_seed=random_seed, - ) - - # Align sample metadata with haplotypes. - df_samples_phased = ( - df_samples.set_index("sample_id").loc[phased_samples.tolist()].reset_index() - ) - - # Normalise color and symbol parameters. - symbol_prepped = self._setup_sample_symbol( - data=df_samples_phased, - symbol=symbol, - ) - del symbol - ( - color_prepped, - color_discrete_map_prepped, - category_orders_prepped, - ) = self._setup_sample_colors_plotly( - data=df_samples_phased, - color=color, - color_discrete_map=color_discrete_map, - color_discrete_sequence=color_discrete_sequence, - category_orders=category_orders, - ) - del color - del color_discrete_map - del color_discrete_sequence - - # Repeat the dataframe so there is one row of metadata for each haplotype. - df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0)) - df_haps.columns = df_samples_phased.columns - - # Configure hover data. - hover_data = self._setup_sample_hover_data_plotly( - color=color_prepped, symbol=symbol_prepped - ) - - # Construct plot title. - if title is True: - title_lines = [] - if sample_sets is not None: - title_lines.append(f"Sample sets: {sample_sets}") - if sample_query is not None: - title_lines.append(f"Sample query: {sample_query}") - title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)") - title = "
".join(title_lines) - - # Create the plot. - with self._spinner("Plot dendrogram"): - fig = plot_dendrogram( - dist=dist, - linkage_method=linkage_method, - count_sort=count_sort, - distance_sort=distance_sort, - render_mode=render_mode, - width=width, - height=height, - title=title, - line_width=line_width, - line_color=line_color, - marker_size=marker_size, - leaf_data=df_haps, - leaf_hover_name="sample_id", - leaf_hover_data=hover_data, - leaf_color=color_prepped, - leaf_symbol=symbol_prepped, - leaf_y=leaf_y, - leaf_color_discrete_map=color_discrete_map_prepped, - leaf_category_orders=category_orders_prepped, - template="simple_white", - ) - - # Tidy up. - fig.update_layout( - title_font=dict( - size=title_font_size, - ), - legend=dict(itemsizing=legend_sizing, tracegroupgap=0), - ) - - if show: # pragma: no cover - fig.show(renderer=renderer) - return None - else: - return fig - @check_types @doc( summary=""" diff --git a/malariagen_data/plotly_dendrogram.py b/malariagen_data/plotly_dendrogram.py index 20dcbc3b2..4c1bca06f 100644 --- a/malariagen_data/plotly_dendrogram.py +++ b/malariagen_data/plotly_dendrogram.py @@ -25,6 +25,8 @@ def plot_dendrogram( leaf_color_discrete_map, leaf_category_orders, template, + y_axis_title, + y_axis_buffer, ): # Hierarchical clustering. Z = sch.linkage(dist, method=linkage_method) @@ -105,7 +107,7 @@ def plot_dendrogram( # it's above the plot it often overlaps the title, so hiding it # for now. xaxis_title=None, - yaxis_title="Distance (no. SNPs)", + yaxis_title=y_axis_title, showlegend=True, ) @@ -124,7 +126,7 @@ def plot_dendrogram( showline=False, showticklabels=True, ticks="outside", - range=(leaf_y - 1, np.max(dcoord) + 1), + range=(leaf_y - y_axis_buffer, np.max(dcoord) + y_axis_buffer), ) return fig diff --git a/malariagen_data/util.py b/malariagen_data/util.py index ceb5e440e..35360c29e 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -1077,6 +1077,141 @@ def biallelic_diplotype_euclidean(x, y): return np.sqrt(biallelic_diplotype_sqeuclidean(x, y)) +@numba.njit(parallel=True) +def multiallelic_diplotype_pdist(X, metric): + """Optimised implementation of pairwise distance between diplotypes. + + N.B., here we assume the array X provides diplotypes as genotype allele + counts, with axes in the order (n_samples, n_sites, n_alleles). + + Computation will be faster if X is a contiguous (C order) array. + + The metric argument is the function to compute distance for a pair of + diplotypes. This can be a numba jitted function. + + """ + n_samples = X.shape[0] + n_pairs = (n_samples * (n_samples - 1)) // 2 + out = np.zeros(n_pairs, dtype=np.float32) + + # Loop over samples, first in pair. + for i in range(n_samples): + x = X[i, :, :] + + # Loop over observations again, second in pair. + for j in numba.prange(i + 1, n_samples): + y = X[j, :, :] + + # Compute distance for the current pair. + d = metric(x, y) + + # Store result for the current pair. + k = square_to_condensed(i, j, n_samples) + out[k] = d + + return out + + +@numba.njit +def multiallelic_diplotype_mean_cityblock(x, y): + """Compute the mean cityblock distance between two diplotypes x and y. The + diplotype vectors are expected as genotype allele counts, i.e., x and y + should have the same shape (n_sites, n_alleles). + + N.B., here we compute the mean value of the distance over sites where + both individuals have a called genotype. This avoids computing distance + at missing sites. + + """ + n_sites = x.shape[0] + n_alleles = x.shape[1] + distance = np.float32(0) + n_sites_called = np.float32(0) + + # Loop over sites. + for i in range(n_sites): + x_is_called = False + y_is_called = False + d = np.float32(0) + + # Loop over alleles. + for j in range(n_alleles): + # Access allele counts. + xc = np.float32(x[i, j]) + yc = np.float32(y[i, j]) + + # Check if any alleles observed. + x_is_called = x_is_called or (xc > 0) + y_is_called = y_is_called or (yc > 0) + + # Compute cityblock distance (absolute difference). + d += np.fabs(xc - yc) + + # Accumulate distance for the current pair, but only if both samples + # have a called genotype. + if x_is_called and y_is_called: + distance += d + n_sites_called += np.float32(1) + + # Compute the mean distance over sites with called genotypes. + if n_sites_called > 0: + mean_distance = distance / n_sites_called + else: + mean_distance = np.nan + + return mean_distance + + +@numba.njit +def multiallelic_diplotype_mean_sqeuclidean(x, y): + """Compute the mean squared euclidean distance between two diplotypes x and + y. The diplotype vectors are expected as genotype allele counts, i.e., x and + y should have the same shape (n_sites, n_alleles). + + N.B., here we compute the mean value of the distance over sites where + both individuals have a called genotype. This avoids computing distance + at missing sites. + + """ + n_sites = x.shape[0] + n_alleles = x.shape[1] + distance = np.float32(0) + n_sites_called = np.float32(0) + + # Loop over sites. + for i in range(n_sites): + x_is_called = False + y_is_called = False + d = np.float32(0) + + # Loop over alleles. + for j in range(n_alleles): + # Access allele counts. + xc = np.float32(x[i, j]) + yc = np.float32(y[i, j]) + + # Check if any alleles observed. + x_is_called = x_is_called or (xc > 0) + y_is_called = y_is_called or (yc > 0) + + # Compute squared euclidean distance. + d += (xc - yc) ** 2 + + # Accumulate distance for the current pair, but only if both samples + # have a called genotype. + if x_is_called and y_is_called: + distance += d + n_sites_called += np.float32(1) + + # Compute the mean distance over sites with called genotypes. + if n_sites_called > 0: + mean_distance = distance / n_sites_called + else: + mean_distance = np.nan + + return mean_distance + + @numba.njit def trim_alleles(ac): """Remap allele indices to trim out unobserved alleles. diff --git a/notebooks/plot_diplotype_clustering.ipynb b/notebooks/plot_diplotype_clustering.ipynb new file mode 100644 index 000000000..b34fe416b --- /dev/null +++ b/notebooks/plot_diplotype_clustering.ipynb @@ -0,0 +1,449 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "d3f73666", + "metadata": {}, + "outputs": [], + "source": [ + "import malariagen_data" + ] + }, + { + "cell_type": "markdown", + "id": "0b889fd3", + "metadata": {}, + "source": [ + "## Ag3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38abbb28", + "metadata": {}, + "outputs": [], + "source": [ + "ag3 = malariagen_data.Ag3(\n", + " \"simplecache::gs://vo_agam_release\",\n", + " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", + " results_cache=\"results_cache\",\n", + " pre=True,\n", + ")\n", + "ag3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "60bd6d1f-e8bd-499e-8c78-e64d2d7e9210", + "metadata": {}, + "outputs": [], + "source": [ + "# %%time\n", + "# ag3.plot_haplotype_clustering(\n", + "# region=\"2L:2,410,000-2,430,000\",\n", + "# sample_sets=[\"3.0\", \"3.1\", \"3.2\", \"3.3\", \"3.4\", \"3.5\", \"3.6\", \"3.7\"],\n", + "# site_mask=\"gamb_colu\",\n", + "# color=\"taxon\",\n", + "# symbol=\"country\",\n", + "# linkage_method=\"single\",\n", + "# width=1000,\n", + "# height=500,\n", + "# count_sort=True,\n", + "# distance_sort=False,\n", + "# render_mode=\"auto\",\n", + "# show=False,\n", + "# );" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04e74ac9-51db-4bae-875c-1ee50e02d024", + "metadata": {}, + "outputs": [], + "source": [ + "# ag3.plot_haplotype_clustering(\n", + "# region=\"2L:2,410,000-2,430,000\",\n", + "# sample_sets=[\"3.0\", \"3.1\", \"3.2\", \"3.3\", \"3.4\", \"3.5\", \"3.6\", \"3.7\"],\n", + "# site_mask=\"gamb_colu\",\n", + "# color=\"country\",\n", + "# linkage_method=\"single\",\n", + "# width=None,\n", + "# height=500,\n", + "# count_sort=True,\n", + "# )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "637ae668-c3b8-4914-a0c2-1587cc252c0c", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:2,410,000-2,430,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n", + " site_mask=\"gamb_colu\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bca4ab4-89b4-4238-9f82-d7920e107583", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:2,410,000-2,430,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n", + " site_mask=\"gamb_colu\",\n", + " color=\"taxon\",\n", + " symbol=\"country\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ccb2366f", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:2,410,000-2,430,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n", + " site_mask=\"gamb_colu\",\n", + " color=\"taxon\",\n", + " symbol=\"country\",\n", + " linkage_method=\"complete\",\n", + " width=1000,\n", + " height=500,\n", + " count_sort=True,\n", + " distance_sort=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0420334-124d-4bab-9ef7-bb9ceb99d6f0", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:2,410,000-2,430,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n", + " site_mask=\"gamb_colu\",\n", + " color=\"taxon\",\n", + " symbol=\"country\",\n", + " linkage_method=\"complete\",\n", + " width=1000,\n", + " height=500,\n", + " count_sort=True,\n", + " distance_sort=False,\n", + " render_mode=\"svg\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e957de19-af70-412a-8c75-2e6835ffb2cd", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:2,410,000-2,430,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n", + " site_mask=\"gamb_colu\",\n", + " color=\"country\",\n", + " symbol=\"taxon\",\n", + " linkage_method=\"single\",\n", + " width=1000,\n", + " height=500,\n", + " count_sort=True,\n", + " distance_sort=False,\n", + " render_mode=\"webgl\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "857145af", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2L:28,545,000-28,550,000\",\n", + " sample_sets=[\"AG1000G-GH\", \"1244-VO-GH-YAWSON-VMF00051\"],\n", + " site_mask=\"gamb_colu\",\n", + " color=\"admin1_name\",\n", + " symbol=\"taxon\",\n", + " cohort_size=500,\n", + " linkage_method=\"weighted\",\n", + " distance_metric=\"euclidean\",\n", + " count_sort=True,\n", + " width=1200,\n", + " height=600,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f41b7e63", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " color=\"sample_set\",\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b93d6d94-1add-4ba4-8618-87a11c843e6e", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " color=\"admin1_year\",\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5773cf5a-6040-4ed9-8bdd-01d7f16408a7", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " color=\"admin1_year\",\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + " title=None,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52c159c8-1751-4be7-92d9-52bd17751259", + "metadata": {}, + "outputs": [], + "source": [ + "new_cohorts = {\n", + " \"East\": \"country in ['Malawi', 'Tanzania', 'Kenya', 'Uganda']\",\n", + " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n", + "}\n", + "other_cohorts = {\n", + " \"East\": \"country in ['Malawi']\",\n", + " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd7b04c-5401-4c1e-92eb-cae5b3d87851", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " distance_metric=\"euclidean\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " color=new_cohorts,\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f6415e6-8d0d-4446-9adb-58c5b236dbcc", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " color=other_cohorts,\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31df6632-8025-4f55-9d5d-9d04978b1690", + "metadata": {}, + "outputs": [], + "source": [ + "ag3.plot_diplotype_clustering(\n", + " region=\"2R:28,480,000-28,490,000\",\n", + " sample_sets=[\"3.0\"],\n", + " sample_query=\"taxon == 'arabiensis'\",\n", + " site_mask=\"gamb_colu_arab\",\n", + " symbol=new_cohorts,\n", + " color=\"year\",\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6016ee62", + "metadata": {}, + "source": [ + "## Af1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87f36d07", + "metadata": {}, + "outputs": [], + "source": [ + "af1 = malariagen_data.Af1(\n", + " \"simplecache::gs://vo_afun_release\",\n", + " simplecache=dict(cache_storage=\"../gcs_cache\"),\n", + " debug=False,\n", + " pre=True,\n", + ")\n", + "af1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f68c6c", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_diplotype_clustering(\n", + " region=\"2RL:2,410,000-2,430,000\",\n", + " sample_sets=[\"1240-VO-CD-KOEKEMOER-VMF00099\", \"1240-VO-MZ-KOEKEMOER-VMF00101\"],\n", + " color=\"sample_set\",\n", + " symbol=\"country\",\n", + " linkage_method=\"complete\",\n", + " distance_metric=\"euclidean\",\n", + " width=1000,\n", + " height=500,\n", + " count_sort=True,\n", + " distance_sort=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bc93110", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_diplotype_clustering(\n", + " region=\"2RL:28,545,000-28,550,000\",\n", + " sample_sets=[\"1240-VO-CD-KOEKEMOER-VMF00099\", \"1240-VO-MZ-KOEKEMOER-VMF00101\"],\n", + " color=\"country\",\n", + " symbol=\"sample_set\",\n", + " cohort_size=80,\n", + " linkage_method=\"weighted\",\n", + " count_sort=True,\n", + " width=1200,\n", + " height=600,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45bdce7c", + "metadata": {}, + "outputs": [], + "source": [ + "af1.plot_diplotype_clustering(\n", + " region=\"2RL:28,480,000-28,490,000\",\n", + " sample_sets=[\"1.0\"],\n", + " sample_query=\"country == 'Ghana'\",\n", + " color=\"sample_set\",\n", + " cohort_size=None,\n", + " width=1000,\n", + " height=400,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5979cca3", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "vscode": { + "interpreter": { + "hash": "3b9ddb1005cd06989fd869b9e3d566470f1be01faa610bb17d64e58e32302e8b" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py index ad15f2c0b..95f5eb006 100644 --- a/tests/anoph/conftest.py +++ b/tests/anoph/conftest.py @@ -1016,6 +1016,10 @@ def random_region_str(self, region_size=None): contig_size = self.contig_sizes[contig] region_start = randint(1, contig_size) if region_size: + # Ensure we the region span doesn't exceed the contig size. + if contig_size - region_start < region_size: + region_start = contig_size - region_size + region_end = region_start + region_size else: region_end = randint(region_start, contig_size) diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py new file mode 100644 index 000000000..69dc0cb3c --- /dev/null +++ b/tests/anoph/test_dipclust.py @@ -0,0 +1,102 @@ +import random +import pytest +from pytest_cases import parametrize_with_cases + +from malariagen_data import af1 as _af1 +from malariagen_data import ag3 as _ag3 +from malariagen_data.anoph.dipclust import AnophelesDipClustAnalysis + + +@pytest.fixture +def ag3_sim_api(ag3_sim_fixture): + return AnophelesDipClustAnalysis( + url=ag3_sim_fixture.url, + config_path=_ag3.CONFIG_PATH, + gcs_url=_ag3.GCS_URL, + major_version_number=_ag3.MAJOR_VERSION_NUMBER, + major_version_path=_ag3.MAJOR_VERSION_PATH, + pre=True, + aim_metadata_dtype={ + "aim_species_fraction_arab": "float64", + "aim_species_fraction_colu": "float64", + "aim_species_fraction_colu_no2l": "float64", + "aim_species_gambcolu_arabiensis": object, + "aim_species_gambiae_coluzzii": object, + "aim_species": object, + }, + gff_gene_type="gene", + gff_gene_name_attribute="Name", + gff_default_attributes=("ID", "Parent", "Name", "description"), + default_site_mask="gamb_colu_arab", + results_cache=ag3_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_ag3.TAXON_COLORS, + virtual_contigs=_ag3.VIRTUAL_CONTIGS, + ) + + +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesDipClustAnalysis( + url=af1_sim_fixture.url, + config_path=_af1.CONFIG_PATH, + gcs_url=_af1.GCS_URL, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + default_site_mask="funestus", + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + ) + + +# N.B., here we use pytest_cases to parametrize tests. Each +# function whose name begins with "case_" defines a set of +# inputs to the test functions. See the documentation for +# pytest_cases for more information, e.g.: +# +# https://smarie.github.io/python-pytest-cases/#basic-usage +# +# We use this approach here because we want to use fixtures +# as test parameters, which is otherwise hard to do with +# pytest alone. + + +def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): + return ag3_sim_fixture, ag3_sim_api + + +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + +@pytest.mark.parametrize("distance_metric", ["cityblock", "euclidean"]) +@parametrize_with_cases("fixture,api", cases=".") +def test_plot_diplotype_clustering( + fixture, api: AnophelesDipClustAnalysis, distance_metric +): + # Set up test parameters. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + linkage_methods = ( + "single", + "complete", + "average", + "weighted", + "centroid", + "median", + "ward", + ) + sample_queries = (None, "sex_call == 'F'") + dipclust_params = dict( + region=fixture.random_region_str(region_size=5000), + sample_sets=[random.choice(all_sample_sets)], + linkage_method=random.choice(linkage_methods), + distance_metric=distance_metric, + sample_query=random.choice(sample_queries), + show=False, + ) + + # Run checks. + api.plot_diplotype_clustering(**dipclust_params) diff --git a/tests/anoph/test_hapclust.py b/tests/anoph/test_hapclust.py new file mode 100644 index 000000000..29a5c4087 --- /dev/null +++ b/tests/anoph/test_hapclust.py @@ -0,0 +1,98 @@ +import random +import pytest +from pytest_cases import parametrize_with_cases + +from malariagen_data import af1 as _af1 +from malariagen_data import ag3 as _ag3 +from malariagen_data.anoph.hapclust import AnophelesHapClustAnalysis + + +@pytest.fixture +def ag3_sim_api(ag3_sim_fixture): + return AnophelesHapClustAnalysis( + url=ag3_sim_fixture.url, + config_path=_ag3.CONFIG_PATH, + gcs_url=_ag3.GCS_URL, + major_version_number=_ag3.MAJOR_VERSION_NUMBER, + major_version_path=_ag3.MAJOR_VERSION_PATH, + pre=True, + aim_metadata_dtype={ + "aim_species_fraction_arab": "float64", + "aim_species_fraction_colu": "float64", + "aim_species_fraction_colu_no2l": "float64", + "aim_species_gambcolu_arabiensis": object, + "aim_species_gambiae_coluzzii": object, + "aim_species": object, + }, + gff_gene_type="gene", + gff_gene_name_attribute="Name", + gff_default_attributes=("ID", "Parent", "Name", "description"), + default_phasing_analysis="gamb_colu_arab", + results_cache=ag3_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_ag3.TAXON_COLORS, + virtual_contigs=_ag3.VIRTUAL_CONTIGS, + ) + + +@pytest.fixture +def af1_sim_api(af1_sim_fixture): + return AnophelesHapClustAnalysis( + url=af1_sim_fixture.url, + config_path=_af1.CONFIG_PATH, + gcs_url=_af1.GCS_URL, + major_version_number=_af1.MAJOR_VERSION_NUMBER, + major_version_path=_af1.MAJOR_VERSION_PATH, + pre=False, + gff_gene_type="protein_coding_gene", + gff_gene_name_attribute="Note", + gff_default_attributes=("ID", "Parent", "Note", "description"), + default_phasing_analysis="funestus", + results_cache=af1_sim_fixture.results_cache_path.as_posix(), + taxon_colors=_af1.TAXON_COLORS, + ) + + +# N.B., here we use pytest_cases to parametrize tests. Each +# function whose name begins with "case_" defines a set of +# inputs to the test functions. See the documentation for +# pytest_cases for more information, e.g.: +# +# https://smarie.github.io/python-pytest-cases/#basic-usage +# +# We use this approach here because we want to use fixtures +# as test parameters, which is otherwise hard to do with +# pytest alone. + + +def case_ag3_sim(ag3_sim_fixture, ag3_sim_api): + return ag3_sim_fixture, ag3_sim_api + + +def case_af1_sim(af1_sim_fixture, af1_sim_api): + return af1_sim_fixture, af1_sim_api + + +@parametrize_with_cases("fixture,api", cases=".") +def test_plot_haplotype_clustering(fixture, api: AnophelesHapClustAnalysis): + # Set up test parameters. + all_sample_sets = api.sample_sets()["sample_set"].to_list() + linkage_methods = ( + "single", + "complete", + "average", + "weighted", + "centroid", + "median", + "ward", + ) + sample_queries = (None, "sex_call == 'F'") + hapclust_params = dict( + region=fixture.random_region_str(region_size=5000), + sample_sets=[random.choice(all_sample_sets)], + linkage_method=random.choice(linkage_methods), + sample_query=random.choice(sample_queries), + show=False, + ) + + # Run checks. + api.plot_haplotype_clustering(**hapclust_params)