Skip to content

Commit

Permalink
Improve performance of haplotype clustering (#450)
Browse files Browse the repository at this point in the history
* add sklearn; free up some constraints

* poetry update

* compute distance matrix only once

* fix scikit-learn dependency

* try fortran order

* expose render_mode

* use sklearn pairwise_distances

* faster hamming

* tidy up

* check notebook

* tweak

* cache haplotype_pairwise_distances

* back out scikit-learn as not needed yet

* fix typing

* tidy

* optimise hamming pdist

* manually merge in changes from #441

* fix typing
  • Loading branch information
alimanfoo authored Nov 28, 2023
1 parent 5198e85 commit 64a3e71
Show file tree
Hide file tree
Showing 7 changed files with 2,091 additions and 1,782 deletions.
7 changes: 7 additions & 0 deletions malariagen_data/anoph/plotly_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@

renderer: TypeAlias = Annotated[Optional[str], "The name of the renderer to use."]

render_mode: TypeAlias = Annotated[
Literal["auto", "svg", "webgl"],
"The type of rendering backend to use. See also https://plotly.com/python/webgl-vs-svg/",
]

render_mode_default: render_mode = "auto"

figure: TypeAlias = Annotated[
Optional[go.Figure], "A plotly figure (only returned if show=False)."
]
Expand Down
294 changes: 195 additions & 99 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
locate_region,
parse_multi_region,
parse_single_region,
pdist_abs_hamming,
plotly_discrete_legend,
region_str,
simple_xarray_concat,
Expand Down Expand Up @@ -6324,6 +6325,98 @@ def plot_xpehh_gwss_track(
else:
return fig

@doc(
summary="""
Compute pairwise distances between haplotypes.
""",
returns=("dist", "phased_samples", "n_snps"),
)
def haplotype_pairwise_distances(
self,
region: base_params.regions,
analysis: hap_params.analysis = DEFAULT,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
cohort_size: Optional[base_params.cohort_size] = None,
random_seed: base_params.random_seed = 42,
) -> Tuple[np.ndarray, np.ndarray, int]:
# Change this name if you ever change the behaviour of this function, to
# invalidate any previously cached data.
name = "haplotype_pairwise_distances"

# Normalize params for consistent hash value.
sample_sets_prepped = self._prep_sample_sets_param(sample_sets=sample_sets)
region_prepped = self._prep_region_cache_param(region=region)
params = dict(
region=region_prepped,
analysis=analysis,
sample_sets=sample_sets_prepped,
sample_query=sample_query,
cohort_size=cohort_size,
random_seed=random_seed,
)

# Try to retrieve results from the cache.
try:
results = self.results_cache_get(name=name, params=params)

except CacheMiss:
results = self._haplotype_pairwise_distances(**params)
self.results_cache_set(name=name, params=params, results=results)

# Unpack results")
dist = results["dist"]
phased_samples = results["phased_samples"]
n_snps = results["n_snps"]

return dist, phased_samples, n_snps

def _haplotype_pairwise_distances(
self,
*,
region,
analysis,
sample_sets,
sample_query,
cohort_size,
random_seed,
):
from scipy.spatial.distance import squareform

# Load haplotypes.
ds_haps = self.haplotypes(
region=region,
analysis=analysis,
sample_query=sample_query,
sample_sets=sample_sets,
cohort_size=cohort_size,
random_seed=random_seed,
)
gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data)
with self._dask_progress(desc="Load haplotypes"):
ht = gt.to_haplotypes().compute().values

# Compute allele count, remove non-segregating sites.
ac = allel.HaplotypeArray(ht).count_alleles(max_allele=1)
ht_seg = ht[ac.is_segregating()]

# Transpose memory layout for faster hamming distance calculations.
ht_t = np.ascontiguousarray(ht_seg.T)

# Compute pairwise distances.
dist_sq = pdist_abs_hamming(ht_t)
dist = squareform(dist_sq)

# Extract IDs of phased samples. Convert to "U" dtype here
# to allow these to be saved to the results cache.
phased_samples = ds_haps["sample_id"].values.astype("U")

return dict(
dist=dist,
phased_samples=phased_samples,
n_snps=ht.shape[0],
)

@doc(
summary="""
Hierarchically cluster haplotypes in region and produce an interactive plot.
Expand All @@ -6338,49 +6431,102 @@ def plot_haplotype_clustering(
analysis: hap_params.analysis = DEFAULT,
sample_sets: Optional[base_params.sample_sets] = None,
sample_query: Optional[base_params.sample_query] = None,
cohort_size: Optional[base_params.cohort_size] = None,
random_seed: base_params.random_seed = 42,
color: plotly_params.color = None,
symbol: plotly_params.symbol = None,
linkage_method: hapclust_params.linkage_method = hapclust_params.linkage_method_default,
count_sort: hapclust_params.count_sort = True,
distance_sort: hapclust_params.distance_sort = False,
cohort_size: Optional[base_params.cohort_size] = None,
random_seed: base_params.random_seed = 42,
width: plotly_params.width = 1000,
height: plotly_params.height = 500,
show: plotly_params.show = True,
renderer: plotly_params.renderer = None,
render_mode: plotly_params.render_mode = plotly_params.render_mode_default,
**kwargs,
) -> plotly_params.figure:
from scipy.cluster.hierarchy import linkage

from .plotly_dendrogram import create_dendrogram

debug = self._log.debug
# Load sample metadata.
df_samples = self.sample_metadata(
sample_sets=sample_sets, sample_query=sample_query
)

ds_haps = self.haplotypes(
# Set up scatter plotting options.
plot_kwargs: Dict[str, Any] = dict(
template="simple_white",
hover_name="sample_id",
render_mode=render_mode,
)

# Handle the color parameter.
if isinstance(color, str):
if color == "taxon":
# Special handling for taxon color.
self._setup_taxon_colors(plot_kwargs)
elif "cohort_" + color in df_samples.columns:
# Convenience to allow things like "admin1_year" instead of "cohort_admin1_year".
color = "cohort_" + color
if color not in df_samples.columns:
raise ValueError(
f"{color!r} is not a known column in the sample metadata."
)

elif isinstance(color, dict):
# Custom grouping of samples using queries.
df_samples["color"] = ""
for key, value in color.items():
df_samples.loc[df_samples.query(value).index, "color"] = key
color = "color"

# Handle the symbol parameter.
if isinstance(symbol, str):
if "cohort_" + symbol in df_samples.columns:
# Convenience to allow things like "admin1_year" instead of "cohort_admin1_year".
symbol = "cohort_" + symbol
if symbol not in df_samples.columns:
raise ValueError(
f"{symbol!r} is not a known column in the sample metadata."
)

elif isinstance(symbol, dict):
df_samples["symbol"] = ""
for key, value in symbol.items():
df_samples.loc[df_samples.query(value).index, "symbol"] = key
symbol = "symbol"

# Compute pairwise distances.
dist, phased_samples, n_snps = self.haplotype_pairwise_distances(
region=region,
analysis=analysis,
sample_query=sample_query,
sample_sets=sample_sets,
sample_query=sample_query,
cohort_size=cohort_size,
random_seed=random_seed,
)
n_haps = len(phased_samples) * 2

gt = allel.GenotypeDaskArray(ds_haps["call_genotype"].data)
with self._dask_progress(desc="Load haplotypes"):
ht = gt.to_haplotypes().compute()

debug("load sample metadata")
df_samples = self.sample_metadata(
sample_sets=sample_sets, sample_query=sample_query
)
debug("align sample metadata with haplotypes")
phased_samples = ds_haps["sample_id"].values.tolist()
# Align sample metadata with haplotypes.
df_samples_phased = (
df_samples.set_index("sample_id").loc[phased_samples].reset_index()
df_samples.set_index("sample_id").loc[phased_samples.tolist()].reset_index()
)

debug("set up plotting options")
# Set labels as the index which we extract to reorder metadata.
leaf_labels = np.arange(n_haps)

# Create the dendrogram.
fig = create_dendrogram(
dist,
linkagefun=lambda x: linkage(x, method=linkage_method),
labels=leaf_labels,
color_threshold=0,
count_sort=count_sort,
distance_sort=distance_sort,
)

# Configure hover data.
hover_data = [
"sample_id",
"partner_sample_id",
Expand All @@ -6394,56 +6540,55 @@ def plot_haplotype_clustering(
"year",
"month",
]

if color and color not in hover_data:
hover_data.append(color)
if symbol and symbol not in hover_data:
hover_data.append(symbol)
plot_kwargs["hover_data"] = hover_data

plot_kwargs = dict(
template="simple_white",
hover_name="sample_id",
hover_data=hover_data,
render_mode="svg",
)
# Apply any user overrides.
plot_kwargs.update(kwargs)

debug("special handling for taxon color")
if color == "taxon":
self._setup_taxon_colors(plot_kwargs)
# Repeat the dataframe so there is one row of metadata for each haplotype.
df_haps = pd.DataFrame(np.repeat(df_samples_phased.values, 2, axis=0))
df_haps.columns = df_samples_phased.columns

debug("apply any user overrides")
plot_kwargs.update(kwargs)
# Select only columns in hover_data.
df_haps = df_haps[hover_data]

debug("Create dendrogram with plotly")
# set labels as the index which we extract to reorder metadata
leaf_labels = np.arange(ht.shape[1])
# get the max distance, required to set xmin, xmax, which we need xmin to be slightly below 0
max_dist = _get_max_hamming_distance(
ht.T, metric="hamming", linkage_method=linkage_method
)
# noinspection PyTypeChecker
fig = create_dendrogram(
ht.T,
distfun=lambda x: _hamming_to_snps(x),
linkagefun=lambda x: linkage(x, method=linkage_method),
labels=leaf_labels,
color_threshold=0,
count_sort=count_sort,
distance_sort=distance_sort,
# Reorder haplotype metadata to align with haplotype clustering.
df_haps = df_haps.loc[fig.layout.xaxis["ticktext"]]

# Add scatter plot to draw the leaves.
fig.add_traces(
list(
px.scatter(
df_haps,
x=fig.layout.xaxis["tickvals"],
y=np.repeat(-1, n_haps),
color=color,
symbol=symbol,
**plot_kwargs,
).select_traces()
)
)

# Add hover for lines to show distance.
fig.update_traces(
hoverinfo="y",
line=dict(width=0.5, color="black"),
)

# Add plot title.
title_lines = []
if sample_sets is not None:
title_lines.append(f"sample sets: {sample_sets}")
if sample_query is not None:
title_lines.append(f"sample query: {sample_query}")
title_lines.append(f"genomic region: {region} ({ht.shape[0]} SNPs)")
title_lines.append(f"genomic region: {region} ({n_snps:,} SNPs)")
title = "<br>".join(title_lines)

# Style the figure.
fig.update_layout(
width=width,
height=height,
Expand All @@ -6456,34 +6601,12 @@ def plot_haplotype_clustering(
showlegend=True,
)

# Repeat the dataframe so there is one row of metadata for each haplotype
df_samples_phased_haps = pd.DataFrame(
np.repeat(df_samples_phased.values, 2, axis=0)
)
df_samples_phased_haps.columns = df_samples_phased.columns
# select only columns in hover_data
df_samples_phased_haps = df_samples_phased_haps[hover_data]
debug("Reorder haplotype metadata to align with haplotype clustering")
df_samples_phased_haps = df_samples_phased_haps.loc[
fig.layout.xaxis["ticktext"]
]
# Style axes.
fig.update_xaxes(mirror=False, showgrid=True, showticklabels=False, ticks="")
fig.update_yaxes(
mirror=False, showgrid=True, showline=True, range=[-2, max_dist + 1]
)

debug("Add scatter plot with hover text")
fig.add_traces(
list(
px.scatter(
df_samples_phased_haps,
x=fig.layout.xaxis["tickvals"],
y=np.repeat(-1, len(ht.T)),
color=color,
symbol=symbol,
**plot_kwargs,
).select_traces()
)
mirror=False,
showgrid=True,
showline=True,
)

if show: # pragma: no cover
Expand Down Expand Up @@ -6821,33 +6944,6 @@ def display_tap_node_data(data):
return app.run_server(**run_params)


def _hamming_to_snps(h):
"""
Cluster haplotype array and return the number of SNP differences
"""
from scipy.spatial.distance import pdist

dist = pdist(h, metric="hamming")
dist *= h.shape[1]
return dist


def _get_max_hamming_distance(h, metric="hamming", linkage_method="single"):
"""
Find the maximum hamming distance between haplotypes
"""
from scipy.cluster.hierarchy import linkage

z = linkage(h, metric=metric, method=linkage_method)

# Get the distances column
dists = z[:, 2]
# Convert to the number of SNP differences
dists *= h.shape[1]
# Return the maximum
return dists.max()


def _diplotype_frequencies(gt):
"""Compute diplotype frequencies, returning a dictionary that maps
diplotype hash values to frequencies."""
Expand Down
Loading

0 comments on commit 64a3e71

Please sign in to comment.