From fac34695811a065275617a3390e960159f718450 Mon Sep 17 00:00:00 2001
From: Sanjay Curtis Nagi <34922269+sanjaynagi@users.noreply.github.com>
Date: Wed, 20 Mar 2024 18:54:57 +0000
Subject: [PATCH] Implement diplotype clustering, refactor haplotype
clustering, and fix random_region_str() (#507)
* add diplotype clustering
* pre commit fixes
* fix remove self from args
* fix not being imported
* change capitalisation of DipClust class
* fix
* fixes2
* fix nb
* add tests
* precommit lint
* test fix
* fix
* test fix
* fix
* fix
* fix random_region_str to prevent simulated regions exceeding contig size
* add sample_query to testW
* remove assert distance metric line
* fix
* refactor hap clust
* fix dip clust y axes + axes title
* slighly more buffer dipclust
* improve codecov
* precommit
* alimanfoo suggestions
* small edits
* fix lint
---
malariagen_data/anoph/dipclust.py | 275 +++++++++++++
malariagen_data/anoph/dipclust_params.py | 24 ++
malariagen_data/anoph/hapclust.py | 271 +++++++++++++
malariagen_data/anopheles.py | 252 +-----------
malariagen_data/plotly_dendrogram.py | 6 +-
malariagen_data/util.py | 135 +++++++
notebooks/plot_diplotype_clustering.ipynb | 449 ++++++++++++++++++++++
tests/anoph/conftest.py | 4 +
tests/anoph/test_dipclust.py | 102 +++++
tests/anoph/test_hapclust.py | 98 +++++
10 files changed, 1366 insertions(+), 250 deletions(-)
create mode 100644 malariagen_data/anoph/dipclust.py
create mode 100644 malariagen_data/anoph/dipclust_params.py
create mode 100644 malariagen_data/anoph/hapclust.py
create mode 100644 notebooks/plot_diplotype_clustering.ipynb
create mode 100644 tests/anoph/test_dipclust.py
create mode 100644 tests/anoph/test_hapclust.py
diff --git a/malariagen_data/anoph/dipclust.py b/malariagen_data/anoph/dipclust.py
new file mode 100644
index 000000000..a8173958c
--- /dev/null
+++ b/malariagen_data/anoph/dipclust.py
@@ -0,0 +1,275 @@
+from typing import Optional, Tuple
+
+import allel # type: ignore
+import numpy as np
+from numpydoc_decorator import doc # type: ignore
+
+from ..util import (
+ CacheMiss,
+ check_types,
+ multiallelic_diplotype_pdist,
+ multiallelic_diplotype_mean_sqeuclidean,
+ multiallelic_diplotype_mean_cityblock,
+)
+from ..plotly_dendrogram import plot_dendrogram
+from . import base_params, plotly_params, tree_params, dipclust_params
+from .base_params import DEFAULT
+from .snp_data import AnophelesSnpData
+
+
+class AnophelesDipClustAnalysis(
+ AnophelesSnpData,
+):
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ # N.B., this class is designed to work cooperatively, and
+ # so it's important that any remaining parameters are passed
+ # to the superclass constructor.
+ super().__init__(**kwargs)
+
+ @check_types
+ @doc(
+ summary="Hierarchically cluster diplotypes in region and produce an interactive plot.",
+ parameters=dict(
+ leaf_y="Y coordinate at which to plot the leaf markers.",
+ ),
+ )
+ def plot_diplotype_clustering(
+ self,
+ region: base_params.regions,
+ site_mask: base_params.site_mask = DEFAULT,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ cohort_size: Optional[base_params.cohort_size] = None,
+ random_seed: base_params.random_seed = 42,
+ color: plotly_params.color = None,
+ symbol: plotly_params.symbol = None,
+ linkage_method: dipclust_params.linkage_method = dipclust_params.linkage_method_default,
+ distance_metric: dipclust_params.distance_metric = dipclust_params.distance_metric_default,
+ count_sort: Optional[tree_params.count_sort] = None,
+ distance_sort: Optional[tree_params.distance_sort] = None,
+ title: plotly_params.title = True,
+ title_font_size: plotly_params.title_font_size = 14,
+ width: plotly_params.width = None,
+ height: plotly_params.height = 500,
+ show: plotly_params.show = True,
+ renderer: plotly_params.renderer = None,
+ render_mode: plotly_params.render_mode = "svg",
+ leaf_y: int = 0,
+ marker_size: plotly_params.marker_size = 5,
+ line_width: plotly_params.line_width = 0.5,
+ line_color: plotly_params.line_color = "black",
+ color_discrete_sequence: plotly_params.color_discrete_sequence = None,
+ color_discrete_map: plotly_params.color_discrete_map = None,
+ category_orders: plotly_params.category_order = None,
+ legend_sizing: plotly_params.legend_sizing = "constant",
+ ) -> plotly_params.figure:
+ import sys
+
+ debug = self._log.debug
+
+ # Normalise params.
+ if count_sort is None and distance_sort is None:
+ count_sort = True
+ distance_sort = False
+
+ # This is needed to avoid RecursionError on some diplotype clustering analyses
+ # with larger numbers of samples.
+ sys.setrecursionlimit(10_000)
+
+ debug("load sample metadata")
+ df_samples = self.sample_metadata(
+ sample_sets=sample_sets, sample_query=sample_query
+ )
+
+ dist, gt_samples, n_snps_used = self.diplotype_pairwise_distances(
+ region=region,
+ site_mask=site_mask,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ cohort_size=cohort_size,
+ distance_metric=distance_metric,
+ random_seed=random_seed,
+ )
+
+ # Align sample metadata with genotypes.
+ df_samples = (
+ df_samples.set_index("sample_id").loc[gt_samples.tolist()].reset_index()
+ )
+
+ # Normalise color and symbol parameters.
+ symbol_prepped = self._setup_sample_symbol(
+ data=df_samples,
+ symbol=symbol,
+ )
+ del symbol
+ (
+ color_prepped,
+ color_discrete_map_prepped,
+ category_orders_prepped,
+ ) = self._setup_sample_colors_plotly(
+ data=df_samples,
+ color=color,
+ color_discrete_map=color_discrete_map,
+ color_discrete_sequence=color_discrete_sequence,
+ category_orders=category_orders,
+ )
+ del color
+ del color_discrete_map
+ del color_discrete_sequence
+
+ # Configure hover data.
+ hover_data = self._setup_sample_hover_data_plotly(
+ color=color_prepped, symbol=symbol_prepped
+ )
+
+ # Construct plot title.
+ if title is True:
+ title_lines = []
+ if sample_sets is not None:
+ title_lines.append(f"Sample sets: {sample_sets}")
+ if sample_query is not None:
+ title_lines.append(f"Sample query: {sample_query}")
+ title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)")
+ title = "<br>".join(title_lines)
+
+ # Create the plot.
+ with self._spinner("Plot dendrogram"):
+ fig = plot_dendrogram(
+ dist=dist,
+ linkage_method=linkage_method,
+ count_sort=count_sort,
+ distance_sort=distance_sort,
+ render_mode=render_mode,
+ width=width,
+ height=height,
+ title=title,
+ line_width=line_width,
+ line_color=line_color,
+ marker_size=marker_size,
+ leaf_data=df_samples,
+ leaf_hover_name="sample_id",
+ leaf_hover_data=hover_data,
+ leaf_color=color_prepped,
+ leaf_symbol=symbol_prepped,
+ leaf_y=leaf_y,
+ leaf_color_discrete_map=color_discrete_map_prepped,
+ leaf_category_orders=category_orders_prepped,
+ template="simple_white",
+ y_axis_title=f"Distance ({distance_metric})",
+ y_axis_buffer=0.1,
+ )
+
+ # Tidy up.
+ fig.update_layout(
+ title_font=dict(
+ size=title_font_size,
+ ),
+ legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
+ )
+
+ if show: # pragma: no cover
+ fig.show(renderer=renderer)
+ return None
+ else:
+ return fig
+
+ def diplotype_pairwise_distances(
+ self,
+ region: base_params.regions,
+ site_mask: base_params.site_mask = DEFAULT,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ site_class: Optional[base_params.site_class] = None,
+ cohort_size: Optional[base_params.cohort_size] = None,
+ distance_metric: dipclust_params.distance_metric = dipclust_params.distance_metric_default,
+ random_seed: base_params.random_seed = 42,
+ ) -> Tuple[np.ndarray, np.ndarray, int]:
+ # Change this name if you ever change the behaviour of this function, to
+ # invalidate any previously cached data.
+ name = "diplotype_pairwise_distances_v1"
+
+ # Normalize params for consistent hash value.
+ sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
+ region_prepped = self._prep_region_cache_param(region=region)
+ params = dict(
+ region=region_prepped,
+ site_mask=site_mask,
+ sample_sets=sample_sets_prepped,
+ sample_query=sample_query,
+ site_class=site_class,
+ cohort_size=cohort_size,
+ distance_metric=distance_metric,
+ random_seed=random_seed,
+ )
+
+ # Try to retrieve results from the cache.
+ try:
+ results = self.results_cache_get(name=name, params=params)
+
+ except CacheMiss:
+ results = self._diplotype_pairwise_distances(**params)
+ self.results_cache_set(name=name, params=params, results=results)
+
+ # Unpack results.
+ dist: np.ndarray = results["dist"]
+ gt_samples: np.ndarray = results["gt_samples"]
+ n_snps: int = int(results["n_snps"][()]) # ensure scalar
+
+ return dist, gt_samples, n_snps
+
+ def _diplotype_pairwise_distances(
+ self,
+ *,
+ region,
+ site_mask,
+ sample_sets,
+ sample_query,
+ site_class,
+ cohort_size,
+ distance_metric,
+ random_seed,
+ ):
+ if distance_metric == "cityblock":
+ metric = multiallelic_diplotype_mean_cityblock
+ elif distance_metric == "euclidean":
+ metric = multiallelic_diplotype_mean_sqeuclidean
+
+ # Load SNP calls.
+ ds_snps = self.snp_calls(
+ region=region,
+ sample_query=sample_query,
+ sample_sets=sample_sets,
+ site_mask=site_mask,
+ site_class=site_class,
+ cohort_size=cohort_size,
+ random_seed=random_seed,
+ )
+
+ with self._dask_progress(desc="Load genotypes"):
+ gt = ds_snps["call_genotype"].data.compute()
+
+ with self._spinner(
+ desc="Compute allele counts and remove non-segregating sites"
+ ):
+ # Compute allele count, remove non-segregating sites.
+ ac = allel.GenotypeArray(gt).count_alleles(max_allele=3)
+ gt_seg = gt.compress(ac.is_segregating(), axis=0)
+ ac_seg = allel.GenotypeArray(gt_seg).to_allele_counts(max_allele=3)
+ X = np.ascontiguousarray(np.swapaxes(ac_seg.values, 0, 1))
+
+ # Compute pairwise distances.
+ with self._spinner(desc="Compute pairwise distances"):
+ dist = multiallelic_diplotype_pdist(X, metric=metric)
+
+ # Extract IDs of samples. Convert to "U" dtype here
+ # to allow these to be saved to the results cache.
+ gt_samples = ds_snps["sample_id"].values.astype("U")
+
+ return dict(
+ dist=dist,
+ gt_samples=gt_samples,
+ n_snps=np.array(gt_seg.shape[0]),
+ )
diff --git a/malariagen_data/anoph/dipclust_params.py b/malariagen_data/anoph/dipclust_params.py
new file mode 100644
index 000000000..d55436b5d
--- /dev/null
+++ b/malariagen_data/anoph/dipclust_params.py
@@ -0,0 +1,24 @@
+"""Parameters for diplotype clustering functions."""
+
+from typing import Literal
+
+from typing_extensions import Annotated, TypeAlias
+
+linkage_method: TypeAlias = Annotated[
+ Literal["single", "complete", "average", "weighted", "centroid", "median", "ward"],
+ """
+ The linkage algorithm to use. See the Linkage Methods section of the
+ scipy.cluster.hierarchy.linkage docs for full descriptions.
+ """,
+]
+
+linkage_method_default: linkage_method = "complete"
+
+distance_metric: TypeAlias = Annotated[
+ Literal["cityblock", "euclidean"],
+ """
+ The distance metric to use. Either "cityblock" or "euclidean".
+ """,
+]
+
+distance_metric_default: distance_metric = "cityblock"
diff --git a/malariagen_data/anoph/hapclust.py b/malariagen_data/anoph/hapclust.py
new file mode 100644
index 000000000..fa3bc2ea2
--- /dev/null
+++ b/malariagen_data/anoph/hapclust.py
@@ -0,0 +1,271 @@
+from typing import Optional, Tuple
+
+import allel # type: ignore
+import numpy as np
+import pandas as pd
+from numpydoc_decorator import doc # type: ignore
+
+from ..util import CacheMiss, check_types, pdist_abs_hamming
+from ..plotly_dendrogram import plot_dendrogram
+from . import base_params, plotly_params, tree_params, hap_params, hapclust_params
+from .base_params import DEFAULT
+from .snp_data import AnophelesSnpData
+from .hap_data import AnophelesHapData
+
+
+class AnophelesHapClustAnalysis(AnophelesHapData, AnophelesSnpData):
+ def __init__(
+ self,
+ **kwargs,
+ ):
+ # N.B., this class is designed to work cooperatively, and
+ # so it's important that any remaining parameters are passed
+ # to the superclass constructor.
+ super().__init__(**kwargs)
+
+ @check_types
+ @doc(
+ summary="""
+ Hierarchically cluster haplotypes in region and produce an interactive plot.
+ """,
+ parameters=dict(
+ leaf_y="Y coordinate at which to plot the leaf markers.",
+ ),
+ )
+ def plot_haplotype_clustering(
+ self,
+ region: base_params.regions,
+ analysis: hap_params.analysis = DEFAULT,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ cohort_size: Optional[base_params.cohort_size] = None,
+ random_seed: base_params.random_seed = 42,
+ color: plotly_params.color = None,
+ symbol: plotly_params.symbol = None,
+ linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default,
+ count_sort: Optional[tree_params.count_sort] = None,
+ distance_sort: Optional[tree_params.distance_sort] = None,
+ title: plotly_params.title = True,
+ title_font_size: plotly_params.title_font_size = 14,
+ width: plotly_params.width = None,
+ height: plotly_params.height = 500,
+ show: plotly_params.show = True,
+ renderer: plotly_params.renderer = None,
+ render_mode: plotly_params.render_mode = "svg",
+ leaf_y: int = 0,
+ marker_size: plotly_params.marker_size = 5,
+ line_width: plotly_params.line_width = 0.5,
+ line_color: plotly_params.line_color = "black",
+ color_discrete_sequence: plotly_params.color_discrete_sequence = None,
+ color_discrete_map: plotly_params.color_discrete_map = None,
+ category_orders: plotly_params.category_order = None,
+ legend_sizing: plotly_params.legend_sizing = "constant",
+ ) -> plotly_params.figure:
+ import sys
+
+ # Normalise params.
+ if count_sort is None and distance_sort is None:
+ count_sort = True
+ distance_sort = False
+
+ # This is needed to avoid RecursionError on some haplotype clustering analyses
+ # with larger numbers of haplotypes.
+ sys.setrecursionlimit(10_000)
+
+ # Load sample metadata.
+ df_samples = self.sample_metadata(
+ sample_sets=sample_sets, sample_query=sample_query
+ )
+
+ # Compute pairwise distances.
+ dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances(
+ region=region,
+ analysis=analysis,
+ sample_sets=sample_sets,
+ sample_query=sample_query,
+ cohort_size=cohort_size,
+ random_seed=random_seed,
+ )
+
+ # Align sample metadata with haplotypes.
+ df_samples_phased = (
+ df_samples.set_index("sample_id").loc[phased_samples.tolist()].reset_index()
+ )
+
+ # Normalise color and symbol parameters.
+ symbol_prepped = self._setup_sample_symbol(
+ data=df_samples_phased,
+ symbol=symbol,
+ )
+ del symbol
+ (
+ color_prepped,
+ color_discrete_map_prepped,
+ category_orders_prepped,
+ ) = self._setup_sample_colors_plotly(
+ data=df_samples_phased,
+ color=color,
+ color_discrete_map=color_discrete_map,
+ color_discrete_sequence=color_discrete_sequence,
+ category_orders=category_orders,
+ )
+ del color
+ del color_discrete_map
+ del color_discrete_sequence
+
+ # Repeat the dataframe so there is one row of metadata for each haplotype.
+ df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0))
+ df_haps.columns = df_samples_phased.columns
+
+ # Configure hover data.
+ hover_data = self._setup_sample_hover_data_plotly(
+ color=color_prepped, symbol=symbol_prepped
+ )
+
+ # Construct plot title.
+ if title is True:
+ title_lines = []
+ if sample_sets is not None:
+ title_lines.append(f"Sample sets: {sample_sets}")
+ if sample_query is not None:
+ title_lines.append(f"Sample query: {sample_query}")
+ title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)")
+ title = "<br>".join(title_lines)
+
+ # Create the plot.
+ with self._spinner("Plot dendrogram"):
+ fig = plot_dendrogram(
+ dist=dist,
+ linkage_method=linkage_method,
+ count_sort=count_sort,
+ distance_sort=distance_sort,
+ render_mode=render_mode,
+ width=width,
+ height=height,
+ title=title,
+ line_width=line_width,
+ line_color=line_color,
+ marker_size=marker_size,
+ leaf_data=df_haps,
+ leaf_hover_name="sample_id",
+ leaf_hover_data=hover_data,
+ leaf_color=color_prepped,
+ leaf_symbol=symbol_prepped,
+ leaf_y=leaf_y,
+ leaf_color_discrete_map=color_discrete_map_prepped,
+ leaf_category_orders=category_orders_prepped,
+ template="simple_white",
+ y_axis_title="Distance (no. SNPs)",
+ y_axis_buffer=1,
+ )
+
+ # Tidy up.
+ fig.update_layout(
+ title_font=dict(
+ size=title_font_size,
+ ),
+ legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
+ )
+
+ if show: # pragma: no cover
+ fig.show(renderer=renderer)
+ return None
+ else:
+ return fig
+
+ @doc(
+ summary="""
+ Compute pairwise distances between haplotypes.
+ """,
+ returns=dict(
+ dist="Pairwise distance.",
+ phased_samples="Sample identifiers for haplotypes.",
+ n_snps="Number of SNPs used.",
+ ),
+ )
+ def haplotype_pairwise_distances(
+ self,
+ region: base_params.regions,
+ analysis: hap_params.analysis = DEFAULT,
+ sample_sets: Optional[base_params.sample_sets] = None,
+ sample_query: Optional[base_params.sample_query] = None,
+ cohort_size: Optional[base_params.cohort_size] = None,
+ random_seed: base_params.random_seed = 42,
+ ) -> Tuple[np.ndarray, np.ndarray, int]:
+ # Change this name if you ever change the behaviour of this function, to
+ # invalidate any previously cached data.
+ name = "haplotype_pairwise_distances"
+
+ # Normalize params for consistent hash value.
+ sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
+ region_prepped = self._prep_region_cache_param(region=region)
+ params = dict(
+ region=region_prepped,
+ analysis=analysis,
+ sample_sets=sample_sets_prepped,
+ sample_query=sample_query,
+ cohort_size=cohort_size,
+ random_seed=random_seed,
+ )
+
+ # Try to retrieve results from the cache.
+ try:
+ results = self.results_cache_get(name=name, params=params)
+
+ except CacheMiss:
+ results = self._haplotype_pairwise_distances(**params)
+ self.results_cache_set(name=name, params=params, results=results)
+
+ # Unpack results.
+ dist: np.ndarray = results["dist"]
+ phased_samples: np.ndarray = results["phased_samples"]
+ n_snps: int = int(results["n_snps"][()]) # ensure scalar
+
+ return dist, phased_samples, n_snps
+
+ def _haplotype_pairwise_distances(
+ self,
+ *,
+ region,
+ analysis,
+ sample_sets,
+ sample_query,
+ cohort_size,
+ random_seed,
+ ):
+ from scipy.spatial.distance import squareform # type: ignore
+
+ # Load haplotypes.
+ ds_haps = self.haplotypes(
+ region=region,
+ analysis=analysis,
+ sample_query=sample_query,
+ sample_sets=sample_sets,
+ cohort_size=cohort_size,
+ random_seed=random_seed,
+ )
+ gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data)
+ with self._dask_progress(desc="Load haplotypes"):
+ ht = gt.to_haplotypes().compute().values
+
+ # Compute allele count, remove non-segregating sites.
+ ac = allel.HaplotypeArray(ht).count_alleles(max_allele=1)
+ ht_seg = ht[ac.is_segregating()]
+
+ # Transpose memory layout for faster hamming distance calculations.
+ ht_t = np.ascontiguousarray(ht_seg.T)
+
+ # Compute pairwise distances.
+ with self._spinner(desc="Compute pairwise distances"):
+ dist_sq = pdist_abs_hamming(ht_t)
+ dist = squareform(dist_sq)
+
+ # Extract IDs of phased samples. Convert to "U" dtype here
+ # to allow these to be saved to the results cache.
+ phased_samples = ds_haps["sample_id"].values.astype("U")
+
+ return dict(
+ dist=dist,
+ phased_samples=phased_samples,
+ n_snps=np.array(ht.shape[0]),
+ )
diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py
index 5fb8ddb70..1962031e8 100644
--- a/malariagen_data/anopheles.py
+++ b/malariagen_data/anopheles.py
@@ -32,7 +32,6 @@
dash_params,
diplotype_distance_params,
gplt_params,
- hapclust_params,
hapnet_params,
het_params,
ihs_params,
@@ -55,6 +54,8 @@
from .anoph.h12 import AnophelesH12Analysis
from .anoph.h1x import AnophelesH1XAnalysis
from .mjn import median_joining_network, mjn_graph
+from .anoph.hapclust import AnophelesHapClustAnalysis
+from .anoph.dipclust import AnophelesDipClustAnalysis
from .util import (
CacheMiss,
Region,
@@ -66,7 +67,6 @@
jackknife_ci,
parse_multi_region,
parse_single_region,
- pdist_abs_hamming,
plotly_discrete_legend,
region_str,
simple_xarray_concat,
@@ -98,6 +98,8 @@
# work around pycharm failing to recognise that doc() is callable
# noinspection PyCallingNonCallable
class AnophelesDataResource(
+ AnophelesDipClustAnalysis,
+ AnophelesHapClustAnalysis,
AnophelesH1XAnalysis,
AnophelesH12Analysis,
AnophelesG123Analysis,
@@ -2632,252 +2634,6 @@ def plot_xpehh_gwss_track(
else:
return fig
- @doc(
- summary="""
- Compute pairwise distances between haplotypes.
- """,
- returns=dict(
- dist="Pairwise distance.",
- phased_samples="Sample identifiers for haplotypes.",
- n_snps="Number of SNPs used.",
- ),
- )
- def haplotype_pairwise_distances(
- self,
- region: base_params.regions,
- analysis: hap_params.analysis = DEFAULT,
- sample_sets: Optional[base_params.sample_sets] = None,
- sample_query: Optional[base_params.sample_query] = None,
- cohort_size: Optional[base_params.cohort_size] = None,
- random_seed: base_params.random_seed = 42,
- ) -> Tuple[np.ndarray, np.ndarray, int]:
- # Change this name if you ever change the behaviour of this function, to
- # invalidate any previously cached data.
- name = "haplotype_pairwise_distances"
-
- # Normalize params for consistent hash value.
- sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
- region_prepped = self._prep_region_cache_param(region=region)
- params = dict(
- region=region_prepped,
- analysis=analysis,
- sample_sets=sample_sets_prepped,
- sample_query=sample_query,
- cohort_size=cohort_size,
- random_seed=random_seed,
- )
-
- # Try to retrieve results from the cache.
- try:
- results = self.results_cache_get(name=name, params=params)
-
- except CacheMiss:
- results = self._haplotype_pairwise_distances(**params)
- self.results_cache_set(name=name, params=params, results=results)
-
- # Unpack results")
- dist: np.ndarray = results["dist"]
- phased_samples: np.ndarray = results["phased_samples"]
- n_snps: int = int(results["n_snps"][()]) # ensure scalar
-
- return dist, phased_samples, n_snps
-
- def _haplotype_pairwise_distances(
- self,
- *,
- region,
- analysis,
- sample_sets,
- sample_query,
- cohort_size,
- random_seed,
- ):
- from scipy.spatial.distance import squareform # type: ignore
-
- # Load haplotypes.
- ds_haps = self.haplotypes(
- region=region,
- analysis=analysis,
- sample_query=sample_query,
- sample_sets=sample_sets,
- cohort_size=cohort_size,
- random_seed=random_seed,
- )
- gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data)
- with self._dask_progress(desc="Load haplotypes"):
- ht = gt.to_haplotypes().compute().values
-
- # Compute allele count, remove non-segregating sites.
- ac = allel.HaplotypeArray(ht).count_alleles(max_allele=1)
- ht_seg = ht[ac.is_segregating()]
-
- # Transpose memory layout for faster hamming distance calculations.
- ht_t = np.ascontiguousarray(ht_seg.T)
-
- # Compute pairwise distances.
- with self._spinner(desc="Compute pairwise distances"):
- dist_sq = pdist_abs_hamming(ht_t)
- dist = squareform(dist_sq)
-
- # Extract IDs of phased samples. Convert to "U" dtype here
- # to allow these to be saved to the results cache.
- phased_samples = ds_haps["sample_id"].values.astype("U")
-
- return dict(
- dist=dist,
- phased_samples=phased_samples,
- n_snps=np.array(ht.shape[0]),
- )
-
- @doc(
- summary="""
- Hierarchically cluster haplotypes in region and produce an interactive plot.
- """,
- parameters=dict(
- leaf_y="Y coordinate at which to plot the leaf markers.",
- ),
- )
- def plot_haplotype_clustering(
- self,
- region: base_params.regions,
- analysis: hap_params.analysis = DEFAULT,
- sample_sets: Optional[base_params.sample_sets] = None,
- sample_query: Optional[base_params.sample_query] = None,
- cohort_size: Optional[base_params.cohort_size] = None,
- random_seed: base_params.random_seed = 42,
- color: plotly_params.color = None,
- symbol: plotly_params.symbol = None,
- linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default,
- count_sort: Optional[tree_params.count_sort] = None,
- distance_sort: Optional[tree_params.distance_sort] = None,
- title: plotly_params.title = True,
- title_font_size: plotly_params.title_font_size = 14,
- width: plotly_params.width = None,
- height: plotly_params.height = 500,
- show: plotly_params.show = True,
- renderer: plotly_params.renderer = None,
- render_mode: plotly_params.render_mode = "svg",
- leaf_y: int = 0,
- marker_size: plotly_params.marker_size = 5,
- line_width: plotly_params.line_width = 0.5,
- line_color: plotly_params.line_color = "black",
- color_discrete_sequence: plotly_params.color_discrete_sequence = None,
- color_discrete_map: plotly_params.color_discrete_map = None,
- category_orders: plotly_params.category_order = None,
- legend_sizing: plotly_params.legend_sizing = "constant",
- ) -> plotly_params.figure:
- import sys
-
- from .plotly_dendrogram import plot_dendrogram
-
- # Normalise params.
- if count_sort is None and distance_sort is None:
- count_sort = True
- distance_sort = False
-
- # This is needed to avoid RecursionError on some haplotype clustering analyses
- # with larger numbers of haplotypes.
- sys.setrecursionlimit(10_000)
-
- # Load sample metadata.
- df_samples = self.sample_metadata(
- sample_sets=sample_sets, sample_query=sample_query
- )
-
- # Compute pairwise distances.
- dist, phased_samples, n_snps_used = self.haplotype_pairwise_distances(
- region=region,
- analysis=analysis,
- sample_sets=sample_sets,
- sample_query=sample_query,
- cohort_size=cohort_size,
- random_seed=random_seed,
- )
-
- # Align sample metadata with haplotypes.
- df_samples_phased = (
- df_samples.set_index("sample_id").loc[phased_samples.tolist()].reset_index()
- )
-
- # Normalise color and symbol parameters.
- symbol_prepped = self._setup_sample_symbol(
- data=df_samples_phased,
- symbol=symbol,
- )
- del symbol
- (
- color_prepped,
- color_discrete_map_prepped,
- category_orders_prepped,
- ) = self._setup_sample_colors_plotly(
- data=df_samples_phased,
- color=color,
- color_discrete_map=color_discrete_map,
- color_discrete_sequence=color_discrete_sequence,
- category_orders=category_orders,
- )
- del color
- del color_discrete_map
- del color_discrete_sequence
-
- # Repeat the dataframe so there is one row of metadata for each haplotype.
- df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0))
- df_haps.columns = df_samples_phased.columns
-
- # Configure hover data.
- hover_data = self._setup_sample_hover_data_plotly(
- color=color_prepped, symbol=symbol_prepped
- )
-
- # Construct plot title.
- if title is True:
- title_lines = []
- if sample_sets is not None:
- title_lines.append(f"Sample sets: {sample_sets}")
- if sample_query is not None:
- title_lines.append(f"Sample query: {sample_query}")
- title_lines.append(f"Genomic region: {region} ({n_snps_used:,} SNPs)")
- title = "<br>".join(title_lines)
-
- # Create the plot.
- with self._spinner("Plot dendrogram"):
- fig = plot_dendrogram(
- dist=dist,
- linkage_method=linkage_method,
- count_sort=count_sort,
- distance_sort=distance_sort,
- render_mode=render_mode,
- width=width,
- height=height,
- title=title,
- line_width=line_width,
- line_color=line_color,
- marker_size=marker_size,
- leaf_data=df_haps,
- leaf_hover_name="sample_id",
- leaf_hover_data=hover_data,
- leaf_color=color_prepped,
- leaf_symbol=symbol_prepped,
- leaf_y=leaf_y,
- leaf_color_discrete_map=color_discrete_map_prepped,
- leaf_category_orders=category_orders_prepped,
- template="simple_white",
- )
-
- # Tidy up.
- fig.update_layout(
- title_font=dict(
- size=title_font_size,
- ),
- legend=dict(itemsizing=legend_sizing, tracegroupgap=0),
- )
-
- if show: # pragma: no cover
- fig.show(renderer=renderer)
- return None
- else:
- return fig
-
@check_types
@doc(
summary="""
diff --git a/malariagen_data/plotly_dendrogram.py b/malariagen_data/plotly_dendrogram.py
index 20dcbc3b2..4c1bca06f 100644
--- a/malariagen_data/plotly_dendrogram.py
+++ b/malariagen_data/plotly_dendrogram.py
@@ -25,6 +25,8 @@ def plot_dendrogram(
leaf_color_discrete_map,
leaf_category_orders,
template,
+ y_axis_title,
+ y_axis_buffer,
):
# Hierarchical clustering.
Z = sch.linkage(dist, method=linkage_method)
@@ -105,7 +107,7 @@ def plot_dendrogram(
# it's above the plot it often overlaps the title, so hiding it
# for now.
xaxis_title=None,
- yaxis_title="Distance (no. SNPs)",
+ yaxis_title=y_axis_title,
showlegend=True,
)
@@ -124,7 +126,7 @@ def plot_dendrogram(
showline=False,
showticklabels=True,
ticks="outside",
- range=(leaf_y - 1, np.max(dcoord) + 1),
+ range=(leaf_y - y_axis_buffer, np.max(dcoord) + y_axis_buffer),
)
return fig
diff --git a/malariagen_data/util.py b/malariagen_data/util.py
index ceb5e440e..35360c29e 100644
--- a/malariagen_data/util.py
+++ b/malariagen_data/util.py
@@ -1077,6 +1077,141 @@ def biallelic_diplotype_euclidean(x, y):
return np.sqrt(biallelic_diplotype_sqeuclidean(x, y))
+@numba.njit(parallel=True)
+def multiallelic_diplotype_pdist(X, metric):
+ """Optimised implementation of pairwise distance between diplotypes.
+
+ N.B., here we assume the array X provides diplotypes as genotype allele
+ counts, with axes in the order (n_samples, n_sites, n_alleles).
+
+ Computation will be faster if X is a contiguous (C order) array.
+
+ The metric argument is the function to compute distance for a pair of
+ diplotypes. This can be a numba jitted function.
+
+ """
+ n_samples = X.shape[0]
+ n_pairs = (n_samples * (n_samples - 1)) // 2
+ out = np.zeros(n_pairs, dtype=np.float32)
+
+ # Loop over samples, first in pair.
+ for i in range(n_samples):
+ x = X[i, :, :]
+
+ # Loop over observations again, second in pair.
+ for j in numba.prange(i + 1, n_samples):
+ y = X[j, :, :]
+
+ # Compute distance for the current pair.
+ d = metric(x, y)
+
+ # Store result for the current pair.
+ k = square_to_condensed(i, j, n_samples)
+ out[k] = d
+
+ return out
+
+
+@numba.njit
+def multiallelic_diplotype_mean_cityblock(x, y):
+ """Compute the mean cityblock distance between two diplotypes x and y. The
+ diplotype vectors are expected as genotype allele counts, i.e., x and y
+ should have the same shape (n_sites, n_alleles).
+
+ N.B., here we compute the mean value of the distance over sites where
+ both individuals have a called genotype. This avoids computing distance
+ at missing sites.
+
+ """
+ n_sites = x.shape[0]
+ n_alleles = x.shape[1]
+ distance = np.float32(0)
+ n_sites_called = np.float32(0)
+
+ # Loop over sites.
+ for i in range(n_sites):
+ x_is_called = False
+ y_is_called = False
+ d = np.float32(0)
+
+ # Loop over alleles.
+ for j in range(n_alleles):
+ # Access allele counts.
+ xc = np.float32(x[i, j])
+ yc = np.float32(y[i, j])
+
+ # Check if any alleles observed.
+ x_is_called = x_is_called or (xc > 0)
+ y_is_called = y_is_called or (yc > 0)
+
+ # Compute cityblock distance (absolute difference).
+ d += np.fabs(xc - yc)
+
+ # Accumulate distance for the current pair, but only if both samples
+ # have a called genotype.
+ if x_is_called and y_is_called:
+ distance += d
+ n_sites_called += np.float32(1)
+
+ # Compute the mean distance over sites with called genotypes.
+ if n_sites_called > 0:
+ mean_distance = distance / n_sites_called
+ else:
+ mean_distance = np.nan
+
+ return mean_distance
+
+
+@numba.njit
+def multiallelic_diplotype_mean_sqeuclidean(x, y):
+ """Compute the mean squared euclidean distance between two diplotypes x and
+ y. The diplotype vectors are expected as genotype allele counts, i.e., x and
+ y should have the same shape (n_sites, n_alleles).
+
+ N.B., here we compute the mean value of the distance over sites where
+ both individuals have a called genotype. This avoids computing distance
+ at missing sites.
+
+ """
+ n_sites = x.shape[0]
+ n_alleles = x.shape[1]
+ distance = np.float32(0)
+ n_sites_called = np.float32(0)
+
+ # Loop over sites.
+ for i in range(n_sites):
+ x_is_called = False
+ y_is_called = False
+ d = np.float32(0)
+
+ # Loop over alleles.
+ for j in range(n_alleles):
+ # Access allele counts.
+ xc = np.float32(x[i, j])
+ yc = np.float32(y[i, j])
+
+ # Check if any alleles observed.
+ x_is_called = x_is_called or (xc > 0)
+ y_is_called = y_is_called or (yc > 0)
+
+ # Compute squared euclidean distance.
+ d += (xc - yc) ** 2
+
+ # Accumulate distance for the current pair, but only if both samples
+ # have a called genotype.
+ if x_is_called and y_is_called:
+ distance += d
+ n_sites_called += np.float32(1)
+
+ # Compute the mean distance over sites with called genotypes.
+ if n_sites_called > 0:
+ mean_distance = distance / n_sites_called
+ else:
+ mean_distance = np.nan
+
+ return mean_distance
+
+
@numba.njit
def trim_alleles(ac):
"""Remap allele indices to trim out unobserved alleles.
diff --git a/notebooks/plot_diplotype_clustering.ipynb b/notebooks/plot_diplotype_clustering.ipynb
new file mode 100644
index 000000000..b34fe416b
--- /dev/null
+++ b/notebooks/plot_diplotype_clustering.ipynb
@@ -0,0 +1,449 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d3f73666",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import malariagen_data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0b889fd3",
+ "metadata": {},
+ "source": [
+ "## Ag3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "38abbb28",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3 = malariagen_data.Ag3(\n",
+ " \"simplecache::gs://vo_agam_release\",\n",
+ " simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
+ " results_cache=\"results_cache\",\n",
+ " pre=True,\n",
+ ")\n",
+ "ag3"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "60bd6d1f-e8bd-499e-8c78-e64d2d7e9210",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# %%time\n",
+ "# ag3.plot_haplotype_clustering(\n",
+ "# region=\"2L:2,410,000-2,430,000\",\n",
+ "# sample_sets=[\"3.0\", \"3.1\", \"3.2\", \"3.3\", \"3.4\", \"3.5\", \"3.6\", \"3.7\"],\n",
+ "# site_mask=\"gamb_colu\",\n",
+ "# color=\"taxon\",\n",
+ "# symbol=\"country\",\n",
+ "# linkage_method=\"single\",\n",
+ "# width=1000,\n",
+ "# height=500,\n",
+ "# count_sort=True,\n",
+ "# distance_sort=False,\n",
+ "# render_mode=\"auto\",\n",
+ "# show=False,\n",
+ "# );"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "04e74ac9-51db-4bae-875c-1ee50e02d024",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# ag3.plot_haplotype_clustering(\n",
+ "# region=\"2L:2,410,000-2,430,000\",\n",
+ "# sample_sets=[\"3.0\", \"3.1\", \"3.2\", \"3.3\", \"3.4\", \"3.5\", \"3.6\", \"3.7\"],\n",
+ "# site_mask=\"gamb_colu\",\n",
+ "# color=\"country\",\n",
+ "# linkage_method=\"single\",\n",
+ "# width=None,\n",
+ "# height=500,\n",
+ "# count_sort=True,\n",
+ "# )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "637ae668-c3b8-4914-a0c2-1587cc252c0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9bca4ab4-89b4-4238-9f82-d7920e107583",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ " color=\"taxon\",\n",
+ " symbol=\"country\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ccb2366f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ " color=\"taxon\",\n",
+ " symbol=\"country\",\n",
+ " linkage_method=\"complete\",\n",
+ " width=1000,\n",
+ " height=500,\n",
+ " count_sort=True,\n",
+ " distance_sort=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c0420334-124d-4bab-9ef7-bb9ceb99d6f0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ " color=\"taxon\",\n",
+ " symbol=\"country\",\n",
+ " linkage_method=\"complete\",\n",
+ " width=1000,\n",
+ " height=500,\n",
+ " count_sort=True,\n",
+ " distance_sort=False,\n",
+ " render_mode=\"svg\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e957de19-af70-412a-8c75-2e6835ffb2cd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"AG1000G-BF-B\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ " color=\"country\",\n",
+ " symbol=\"taxon\",\n",
+ " linkage_method=\"single\",\n",
+ " width=1000,\n",
+ " height=500,\n",
+ " count_sort=True,\n",
+ " distance_sort=False,\n",
+ " render_mode=\"webgl\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "857145af",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2L:28,545,000-28,550,000\",\n",
+ " sample_sets=[\"AG1000G-GH\", \"1244-VO-GH-YAWSON-VMF00051\"],\n",
+ " site_mask=\"gamb_colu\",\n",
+ " color=\"admin1_name\",\n",
+ " symbol=\"taxon\",\n",
+ " cohort_size=500,\n",
+ " linkage_method=\"weighted\",\n",
+ " distance_metric=\"euclidean\",\n",
+ " count_sort=True,\n",
+ " width=1200,\n",
+ " height=600,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f41b7e63",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " color=\"sample_set\",\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b93d6d94-1add-4ba4-8618-87a11c843e6e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " color=\"admin1_year\",\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5773cf5a-6040-4ed9-8bdd-01d7f16408a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " color=\"admin1_year\",\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ " title=None,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "52c159c8-1751-4be7-92d9-52bd17751259",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_cohorts = {\n",
+ " \"East\": \"country in ['Malawi', 'Tanzania', 'Kenya', 'Uganda']\",\n",
+ " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n",
+ "}\n",
+ "other_cohorts = {\n",
+ " \"East\": \"country in ['Malawi']\",\n",
+ " \"West\": \"country in ['Mali', 'Burkina Faso', 'Cameroon']\",\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7bd7b04c-5401-4c1e-92eb-cae5b3d87851",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " distance_metric=\"euclidean\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " color=new_cohorts,\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9f6415e6-8d0d-4446-9adb-58c5b236dbcc",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " color=other_cohorts,\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "31df6632-8025-4f55-9d5d-9d04978b1690",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ag3.plot_diplotype_clustering(\n",
+ " region=\"2R:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"3.0\"],\n",
+ " sample_query=\"taxon == 'arabiensis'\",\n",
+ " site_mask=\"gamb_colu_arab\",\n",
+ " symbol=new_cohorts,\n",
+ " color=\"year\",\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "6016ee62",
+ "metadata": {},
+ "source": [
+ "## Af1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "87f36d07",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "af1 = malariagen_data.Af1(\n",
+ " \"simplecache::gs://vo_afun_release\",\n",
+ " simplecache=dict(cache_storage=\"../gcs_cache\"),\n",
+ " debug=False,\n",
+ " pre=True,\n",
+ ")\n",
+ "af1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e1f68c6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "af1.plot_diplotype_clustering(\n",
+ " region=\"2RL:2,410,000-2,430,000\",\n",
+ " sample_sets=[\"1240-VO-CD-KOEKEMOER-VMF00099\", \"1240-VO-MZ-KOEKEMOER-VMF00101\"],\n",
+ " color=\"sample_set\",\n",
+ " symbol=\"country\",\n",
+ " linkage_method=\"complete\",\n",
+ " distance_metric=\"euclidean\",\n",
+ " width=1000,\n",
+ " height=500,\n",
+ " count_sort=True,\n",
+ " distance_sort=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5bc93110",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "af1.plot_diplotype_clustering(\n",
+ " region=\"2RL:28,545,000-28,550,000\",\n",
+ " sample_sets=[\"1240-VO-CD-KOEKEMOER-VMF00099\", \"1240-VO-MZ-KOEKEMOER-VMF00101\"],\n",
+ " color=\"country\",\n",
+ " symbol=\"sample_set\",\n",
+ " cohort_size=80,\n",
+ " linkage_method=\"weighted\",\n",
+ " count_sort=True,\n",
+ " width=1200,\n",
+ " height=600,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "45bdce7c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "af1.plot_diplotype_clustering(\n",
+ " region=\"2RL:28,480,000-28,490,000\",\n",
+ " sample_sets=[\"1.0\"],\n",
+ " sample_query=\"country == 'Ghana'\",\n",
+ " color=\"sample_set\",\n",
+ " cohort_size=None,\n",
+ " width=1000,\n",
+ " height=400,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5979cca3",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "3b9ddb1005cd06989fd869b9e3d566470f1be01faa610bb17d64e58e32302e8b"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tests/anoph/conftest.py b/tests/anoph/conftest.py
index ad15f2c0b..95f5eb006 100644
--- a/tests/anoph/conftest.py
+++ b/tests/anoph/conftest.py
@@ -1016,6 +1016,10 @@ def random_region_str(self, region_size=None):
contig_size = self.contig_sizes[contig]
region_start = randint(1, contig_size)
if region_size:
+ # Ensure the region span doesn't exceed the contig size.
+ if contig_size - region_start < region_size:
+ region_start = contig_size - region_size
+
region_end = region_start + region_size
else:
region_end = randint(region_start, contig_size)
diff --git a/tests/anoph/test_dipclust.py b/tests/anoph/test_dipclust.py
new file mode 100644
index 000000000..69dc0cb3c
--- /dev/null
+++ b/tests/anoph/test_dipclust.py
@@ -0,0 +1,102 @@
+import random
+import pytest
+from pytest_cases import parametrize_with_cases
+
+from malariagen_data import af1 as _af1
+from malariagen_data import ag3 as _ag3
+from malariagen_data.anoph.dipclust import AnophelesDipClustAnalysis
+
+
+@pytest.fixture
+def ag3_sim_api(ag3_sim_fixture):  # Dipclust API built on the Ag3 fixture.
+ return AnophelesDipClustAnalysis(
+ url=ag3_sim_fixture.url,  # Fixture-provided URL.
+ config_path=_ag3.CONFIG_PATH,
+ gcs_url=_ag3.GCS_URL,
+ major_version_number=_ag3.MAJOR_VERSION_NUMBER,
+ major_version_path=_ag3.MAJOR_VERSION_PATH,
+ pre=True,
+ aim_metadata_dtype={
+ "aim_species_fraction_arab": "float64",
+ "aim_species_fraction_colu": "float64",
+ "aim_species_fraction_colu_no2l": "float64",
+ "aim_species_gambcolu_arabiensis": object,
+ "aim_species_gambiae_coluzzii": object,
+ "aim_species": object,
+ },
+ gff_gene_type="gene",
+ gff_gene_name_attribute="Name",
+ gff_default_attributes=("ID", "Parent", "Name", "description"),
+ default_site_mask="gamb_colu_arab",
+ results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_ag3.TAXON_COLORS,
+ virtual_contigs=_ag3.VIRTUAL_CONTIGS,
+ )
+
+
+@pytest.fixture
+def af1_sim_api(af1_sim_fixture):  # Dipclust API built on the Af1 fixture.
+ return AnophelesDipClustAnalysis(
+ url=af1_sim_fixture.url,  # Fixture-provided URL.
+ config_path=_af1.CONFIG_PATH,
+ gcs_url=_af1.GCS_URL,
+ major_version_number=_af1.MAJOR_VERSION_NUMBER,
+ major_version_path=_af1.MAJOR_VERSION_PATH,
+ pre=False,
+ gff_gene_type="protein_coding_gene",
+ gff_gene_name_attribute="Note",
+ gff_default_attributes=("ID", "Parent", "Note", "description"),
+ default_site_mask="funestus",
+ results_cache=af1_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_af1.TAXON_COLORS,
+ )
+
+
+# N.B., here we use pytest_cases to parametrize tests. Each
+# function whose name begins with "case_" defines a set of
+# inputs to the test functions. See the documentation for
+# pytest_cases for more information, e.g.:
+#
+# https://smarie.github.io/python-pytest-cases/#basic-usage
+#
+# We use this approach here because we want to use fixtures
+# as test parameters, which is otherwise hard to do with
+# pytest alone.
+
+
+def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):  # Supplies (fixture, api) for Ag3.
+ return ag3_sim_fixture, ag3_sim_api
+
+
+def case_af1_sim(af1_sim_fixture, af1_sim_api):  # Supplies (fixture, api) for Af1.
+ return af1_sim_fixture, af1_sim_api
+
+
+@pytest.mark.parametrize("distance_metric", ["cityblock", "euclidean"])
+@parametrize_with_cases("fixture,api", cases=".")
+def test_plot_diplotype_clustering(
+ fixture, api: AnophelesDipClustAnalysis, distance_metric
+):
+ # Randomise the region, sample set, linkage method and sample query.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ linkage_methods = (
+ "single",
+ "complete",
+ "average",
+ "weighted",
+ "centroid",
+ "median",
+ "ward",
+ )
+ sample_queries = (None, "sex_call == 'F'")
+ dipclust_params = dict(
+ region=fixture.random_region_str(region_size=5000),
+ sample_sets=[random.choice(all_sample_sets)],
+ linkage_method=random.choice(linkage_methods),
+ distance_metric=distance_metric,  # Supplied by the parametrize decorator.
+ sample_query=random.choice(sample_queries),
+ show=False,
+ )
+
+ # Smoke test: the call only has to complete; nothing is asserted on it.
+ api.plot_diplotype_clustering(**dipclust_params)
diff --git a/tests/anoph/test_hapclust.py b/tests/anoph/test_hapclust.py
new file mode 100644
index 000000000..29a5c4087
--- /dev/null
+++ b/tests/anoph/test_hapclust.py
@@ -0,0 +1,98 @@
+import random
+import pytest
+from pytest_cases import parametrize_with_cases
+
+from malariagen_data import af1 as _af1
+from malariagen_data import ag3 as _ag3
+from malariagen_data.anoph.hapclust import AnophelesHapClustAnalysis
+
+
+@pytest.fixture
+def ag3_sim_api(ag3_sim_fixture):  # Hapclust API built on the Ag3 fixture.
+ return AnophelesHapClustAnalysis(
+ url=ag3_sim_fixture.url,  # Fixture-provided URL.
+ config_path=_ag3.CONFIG_PATH,
+ gcs_url=_ag3.GCS_URL,
+ major_version_number=_ag3.MAJOR_VERSION_NUMBER,
+ major_version_path=_ag3.MAJOR_VERSION_PATH,
+ pre=True,
+ aim_metadata_dtype={
+ "aim_species_fraction_arab": "float64",
+ "aim_species_fraction_colu": "float64",
+ "aim_species_fraction_colu_no2l": "float64",
+ "aim_species_gambcolu_arabiensis": object,
+ "aim_species_gambiae_coluzzii": object,
+ "aim_species": object,
+ },
+ gff_gene_type="gene",
+ gff_gene_name_attribute="Name",
+ gff_default_attributes=("ID", "Parent", "Name", "description"),
+ default_phasing_analysis="gamb_colu_arab",
+ results_cache=ag3_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_ag3.TAXON_COLORS,
+ virtual_contigs=_ag3.VIRTUAL_CONTIGS,
+ )
+
+
+@pytest.fixture
+def af1_sim_api(af1_sim_fixture):  # Hapclust API built on the Af1 fixture.
+ return AnophelesHapClustAnalysis(
+ url=af1_sim_fixture.url,  # Fixture-provided URL.
+ config_path=_af1.CONFIG_PATH,
+ gcs_url=_af1.GCS_URL,
+ major_version_number=_af1.MAJOR_VERSION_NUMBER,
+ major_version_path=_af1.MAJOR_VERSION_PATH,
+ pre=False,
+ gff_gene_type="protein_coding_gene",
+ gff_gene_name_attribute="Note",
+ gff_default_attributes=("ID", "Parent", "Note", "description"),
+ default_phasing_analysis="funestus",
+ results_cache=af1_sim_fixture.results_cache_path.as_posix(),
+ taxon_colors=_af1.TAXON_COLORS,
+ )
+
+
+# N.B., here we use pytest_cases to parametrize tests. Each
+# function whose name begins with "case_" defines a set of
+# inputs to the test functions. See the documentation for
+# pytest_cases for more information, e.g.:
+#
+# https://smarie.github.io/python-pytest-cases/#basic-usage
+#
+# We use this approach here because we want to use fixtures
+# as test parameters, which is otherwise hard to do with
+# pytest alone.
+
+
+def case_ag3_sim(ag3_sim_fixture, ag3_sim_api):  # Supplies (fixture, api) for Ag3.
+ return ag3_sim_fixture, ag3_sim_api
+
+
+def case_af1_sim(af1_sim_fixture, af1_sim_api):  # Supplies (fixture, api) for Af1.
+ return af1_sim_fixture, af1_sim_api
+
+
+@parametrize_with_cases("fixture,api", cases=".")
+def test_plot_haplotype_clustering(fixture, api: AnophelesHapClustAnalysis):
+ # Randomise the region, sample set, linkage method and sample query.
+ all_sample_sets = api.sample_sets()["sample_set"].to_list()
+ linkage_methods = (
+ "single",
+ "complete",
+ "average",
+ "weighted",
+ "centroid",
+ "median",
+ "ward",
+ )
+ sample_queries = (None, "sex_call == 'F'")
+ hapclust_params = dict(
+ region=fixture.random_region_str(region_size=5000),
+ sample_sets=[random.choice(all_sample_sets)],
+ linkage_method=random.choice(linkage_methods),
+ sample_query=random.choice(sample_queries),
+ show=False,
+ )
+
+ # Smoke test: the call only has to complete; nothing is asserted on it.
+ api.plot_haplotype_clustering(**hapclust_params)