Improve G123 and H12 performance (#493)
* explore faster diplotype frequencies

* accelerate g123

* optimise h12

* typing

* reinstate results cache in notebooks
alimanfoo authored Jan 5, 2024
1 parent f4c263f commit f3e44d5
Showing 5 changed files with 137 additions and 91 deletions.
4 changes: 2 additions & 2 deletions malariagen_data/anoph/base_params.py
@@ -1,6 +1,6 @@
"""General parameters common to many functions in the public API."""

-from typing import Final, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Final, List, Mapping, Optional, Sequence, Tuple, Union, Callable

from typing_extensions import Annotated, TypeAlias

@@ -226,7 +226,7 @@ def validate_sample_selection_params(
inline_array_default: inline_array = True

chunks: TypeAlias = Annotated[
-Union[str, Tuple[int, ...]],
+Union[str, Tuple[int, ...], Callable[[Tuple[int, ...]], Tuple[int, ...]]],
"""
If 'auto' let dask decide chunk size. If 'native' use native zarr
chunks. Also, can be a target size, e.g., '200 MiB', or a tuple of
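The widened `chunks` alias now also admits a callable that maps the native zarr chunk shape to the dask chunk shape to use. A minimal sketch of such a callable (the function name and the doubling policy are illustrative, not part of the API):

```python
from typing import Tuple

# Illustrative chunk-mapping callable: it receives the zarr array's native
# chunk shape and returns the dask chunk shape to request. Here we double
# the chunk length along the first dimension and keep the other axes as-is.
def double_first_dim(native_chunks: Tuple[int, ...]) -> Tuple[int, ...]:
    return (native_chunks[0] * 2,) + tuple(native_chunks[1:])
```

Such a callable can then be passed anywhere the public API accepts a `chunks` argument.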
130 changes: 87 additions & 43 deletions malariagen_data/anopheles.py
@@ -2533,7 +2533,7 @@ def _h12_calibration(

calibration_runs: Dict[str, np.ndarray] = dict()
for window_size in self._progress(window_sizes, desc="Compute H12"):
-h1, h12, h123, h2_h1 = allel.moving_garud_h(ht, size=window_size)
+h12 = allel.moving_statistic(ht, statistic=_garud_h12, size=window_size)
calibration_runs[str(window_size)] = h12

return calibration_runs
@@ -2712,7 +2712,7 @@ def _h12_gwss(

with self._spinner(desc="Compute H12"):
# Compute H12.
-h1, h12, h123, h2_h1 = allel.moving_garud_h(ht, size=window_size)
+h12 = allel.moving_statistic(ht, statistic=_garud_h12, size=window_size)

# Compute window midpoints.
pos = ds_haps["variant_position"].values
@@ -3684,33 +3684,6 @@ def plot_ihs_gwss(
else:
return fig

-def _garud_g123(self, gt):
-"""Compute Garud's G123."""
-
-# compute diplotype frequencies
-frq_counter = _diplotype_frequencies(gt)
-
-# convert to array of sorted frequencies
-f = np.sort(np.fromiter(frq_counter.values(), dtype=float))[::-1]
-
-# compute G123
-g123 = np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)
-
-# These other statistics are not currently needed, but leaving here
-# commented out for future reference...
-
-# compute G1
-# g1 = np.sum(f**2)
-
-# compute G12
-# g12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)  # type: ignore[index]
-
-# compute G2/G1
-# g2 = g1 - f[0] ** 2  # type: ignore[index]
-# g2_g1 = g2 / g1
-
-return g123

@check_types
@doc(
summary="Run a G123 genome-wide selection scan.",
@@ -3808,9 +3781,7 @@ def _g123_gwss(
)

with self._spinner("Compute G123"):
-g123 = allel.moving_statistic(
-gt, statistic=self._garud_g123, size=window_size
-)
+g123 = allel.moving_statistic(gt, statistic=_garud_g123, size=window_size)
x = allel.moving_statistic(pos, statistic=np.mean, size=window_size)

results = dict(x=x, g123=g123)
@@ -4021,10 +3992,11 @@ def _load_data_for_g123(
chunks=chunks,
)

-gt = allel.GenotypeDaskArray(ds_snps["call_genotype"].data)
with self._dask_progress(desc="Load genotypes"):
-gt = gt.compute()
-pos = ds_snps["variant_position"].values
+gt = ds_snps["call_genotype"].data.compute()
+
+with self._dask_progress(desc="Load SNP positions"):
+pos = ds_snps["variant_position"].data.compute()

if sites in self.phasing_analysis_ids:
with self._spinner("Subsetting to selected sites"):
Expand All @@ -4041,7 +4013,7 @@ def _load_data_for_g123(

elif sites == "segregating":
with self._spinner("Subsetting to segregating sites"):
-ac = gt.count_alleles(max_allele=3)
+ac = allel.GenotypeArray(gt).count_alleles(max_allele=3)
seg = ac.is_segregating()
pos = pos[seg]
gt = gt.compress(seg, axis=0)
@@ -4134,9 +4106,7 @@ def _g123_calibration(

calibration_runs: Dict[str, np.ndarray] = dict()
for window_size in self._progress(window_sizes, desc="Compute G123"):
-g123 = allel.moving_statistic(
-gt, statistic=self._garud_g123, size=window_size
-)
+g123 = allel.moving_statistic(gt, statistic=_garud_g123, size=window_size)
calibration_runs[str(window_size)] = g123

return calibration_runs
@@ -5632,28 +5602,102 @@ def _unrooted_tree_layout_equal_angle(
leaf_nodes.append([x, y, tree_node.index, leaf_color])


+@numba.njit
+def _hash_columns(x):
+# Here we want to compute a hash for each column in the
+# input array. However, we assume the input array is in
+# C contiguous order, and therefore we scan the array
+# and perform the computation in this order for more
+# efficient memory access.
+#
+# This function uses the DJBX33A hash function which
+# is much faster than computing Python hashes of
+# bytes, as discovered by Tom White in work on sgkit.
+m = x.shape[0]
+n = x.shape[1]
+out = np.empty(n, dtype=np.int64)
+out[:] = 5381
+for i in range(m):
+for j in range(n):
+v = x[i, j]
+out[j] = out[j] * 33 + v
+return out
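A small sanity sketch of the column-hashing idea (assuming `_hash_columns` above is in scope and numba is installed): identical columns must receive identical DJBX33A hashes, so the hash value can stand in for column identity when counting distinct haplotypes or diplotypes.

```python
import numpy as np

x = np.array(
    [[0, 0, 1],
     [1, 1, 2]],
    dtype=np.int16,
)
h = _hash_columns(x)
assert h[0] == h[1]  # columns 0 and 1 are identical
assert h[0] != h[2]  # column 2 differs; collisions are possible but unlikely
```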


def _diplotype_frequencies(gt):
"""Compute diplotype frequencies, returning a dictionary that maps
diplotype hash values to frequencies."""
-# TODO could use faster hashing

+# Here are some optimisations to speed up the computation
+# of diplotype hashes. First we combine the two int8 alleles
+# in each genotype call into a single int16.
+m = gt.shape[0]
n = gt.shape[1]
-hashes = [hash(gt[:, i].tobytes()) for i in range(n)]
+x = np.asarray(gt).view(np.int16).reshape((m, n))
+
+# Now call optimised hashing function.
+hashes = _hash_columns(x)
+
+# Now compute counts and frequencies of distinct diplotypes.
counts = Counter(hashes)
freqs = {key: count / n for key, count in counts.items()}

return freqs
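The int8-to-int16 packing step above is worth a note: the two int8 alleles of each genotype call occupy adjacent bytes in a C-contiguous array, so a zero-copy view as int16 fuses each call into a single value. A toy illustration (packed values are endianness-dependent, but distinctness is preserved, which is all the hashing needs):

```python
import numpy as np

# 1 variant x 3 samples x 2 alleles; samples 0 and 2 share a genotype call.
gt = np.array([[[0, 1], [1, 1], [0, 1]]], dtype=np.int8)
x = gt.view(np.int16).reshape((1, 3))
assert x[0, 0] == x[0, 2]  # same call, same packed value
assert x[0, 0] != x[0, 1]  # different call, different packed value
```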


+def _garud_g123(gt):
+"""Compute Garud's G123."""
+
+# compute diplotype frequencies
+frq_counter = _diplotype_frequencies(gt)
+
+# convert to array of sorted frequencies
+f = np.sort(np.fromiter(frq_counter.values(), dtype=float))[::-1]
+
+# compute G123
+g123 = np.sum(f[:3]) ** 2 + np.sum(f[3:] ** 2)
+
+# These other statistics are not currently needed, but leaving here
+# commented out for future reference...
+
+# compute G1
+# g1 = np.sum(f**2)
+
+# compute G12
+# g12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)  # type: ignore[index]
+
+# compute G2/G1
+# g2 = g1 - f[0] ** 2  # type: ignore[index]
+# g2_g1 = g2 / g1
+
+return g123
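A worked sketch of G123 on a single window (assuming `_garud_g123` above is in scope): with diplotype frequencies 0.5, 0.25 and 0.25, the top three classes cover everything, so G123 = (0.5 + 0.25 + 0.25)**2 = 1.0.

```python
import allel
import numpy as np

# 2 variants x 4 samples x 2 alleles; samples 0 and 1 share a diplotype.
gt = np.array(
    [[[0, 1], [0, 1], [0, 0], [1, 1]],
     [[1, 1], [1, 1], [0, 1], [0, 0]]],
    dtype=np.int8,
)
g123 = allel.moving_statistic(gt, statistic=_garud_g123, size=2)
assert np.isclose(g123[0], 1.0)
```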


def _haplotype_frequencies(h):
"""Compute haplotype frequencies, returning a dictionary that maps
haplotype hash values to frequencies."""
-# TODO could use faster hashing
n = h.shape[1]
-hashes = [hash(h[:, i].tobytes()) for i in range(n)]
+hashes = _hash_columns(np.asarray(h))
counts = Counter(hashes)
freqs = {key: count / n for key, count in counts.items()}
return freqs


+def _garud_h12(ht):
+"""Compute Garud's H12."""
+
+# compute haplotype frequencies
+frq_counter = _haplotype_frequencies(ht)
+
+# convert to array of sorted frequencies
+f = np.sort(np.fromiter(frq_counter.values(), dtype=float))[::-1]
+
+# compute H12
+h12 = np.sum(f[:2]) ** 2 + np.sum(f[2:] ** 2)
+
+return h12
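A worked sketch of H12 (assuming `_garud_h12` above is in scope), including a cross-check against `allel.moving_garud_h`, whose H12 component this commit replaces in `_h12_calibration` and `_h12_gwss`:

```python
import allel
import numpy as np

# 2 variants x 4 haplotypes; haplotype frequencies are 0.5, 0.25, 0.25,
# so H12 = (0.5 + 0.25)**2 + 0.25**2 = 0.625.
ht = np.array(
    [[0, 0, 1, 0],
     [1, 1, 0, 0]],
    dtype=np.int8,
)
assert np.isclose(_garud_h12(ht), 0.625)

# The new moving_statistic formulation should agree with the reference.
h12_new = allel.moving_statistic(ht, statistic=_garud_h12, size=2)
_, h12_ref, _, _ = allel.moving_garud_h(ht, size=2)
assert np.allclose(h12_new, h12_ref)
```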


def _haplotype_joint_frequencies(ha, hb):
"""Compute the joint frequency of haplotypes in two difference
cohorts. Returns a dictionary mapping haplotype hash values to
20 changes: 15 additions & 5 deletions malariagen_data/util.py
@@ -8,7 +8,7 @@
from functools import wraps
from inspect import getcallargs
from textwrap import dedent, fill
-from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union
+from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union, Callable
from urllib.parse import unquote_plus
from numpy.testing import assert_allclose, assert_array_equal

@@ -168,18 +168,28 @@ class SiteClass(Enum):


def da_from_zarr(
-z: zarr.core.Array, inline_array: bool, chunks: Union[str, Tuple[int, ...]] = "auto"
+z: zarr.core.Array,
+inline_array: bool,
+chunks: Union[
+str, Tuple[int, ...], Callable[[Tuple[int, ...]], Tuple[int, ...]]
+] = "auto",
) -> da.Array:
"""Utility function for turning a zarr array into a dask array.
N.B., dask does have its own from_zarr() function, but we roll
our own here to get a little more control.
"""
-if chunks == "native" or z.dtype == object:
+if callable(chunks):
+dask_chunks: Union[Tuple[int, ...], str] = chunks(z.chunks)
+elif chunks == "native" or z.dtype == object:
# N.B., dask does not support "auto" chunks for arrays with object dtype
-chunks = z.chunks
-kwargs = dict(chunks=chunks, fancy=False, lock=False, inline_array=inline_array)
+dask_chunks = z.chunks
+else:
+dask_chunks = chunks
+kwargs = dict(
+chunks=dask_chunks, fancy=False, lock=False, inline_array=inline_array
+)
try:
d = da.from_array(z, **kwargs)
except TypeError:
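A quick sketch of the new callable path through `da_from_zarr` (assuming the function above is importable; the array and the lambda are illustrative): the callable receives `z.chunks`, and whatever it returns is handed to dask.

```python
import numpy as np
import zarr

# An in-memory zarr array with native chunks of 100,000 elements.
z = zarr.array(np.arange(1_000_000), chunks=(100_000,))

# Double the native chunk length via a callable chunk spec.
d = da_from_zarr(z, inline_array=True, chunks=lambda native: (native[0] * 2,))
assert d.chunks[0] == (200_000,) * 5
```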
21 changes: 16 additions & 5 deletions notebooks/plot_g123_gwss.ipynb
@@ -49,16 +49,24 @@
{
"cell_type": "code",
"execution_count": null,
"id": "db33f6d9",
"id": "44ff6b91-cc69-4683-b408-45f199816d9a",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"contig = \"3L\"\n",
"sample_set = \"AG1000G-BF-A\"\n",
"sample_query = 'taxon == \"gambiae\"'\n",
"site_mask = \"gamb_colu\"\n",
"\n",
"site_mask = \"gamb_colu\"\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "db33f6d9",
+"metadata": {},
+"outputs": [],
+"source": [
+"%%time\n",
"ag3.plot_g123_calibration(\n",
" contig=contig,\n",
" sites=site_mask,\n",
@@ -130,6 +138,7 @@
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"ag3.plot_g123_gwss_track(\n",
" contig=\"2L\",\n",
" window_size=1_000,\n",
@@ -150,6 +159,7 @@
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"ag3.plot_g123_gwss(\n",
" contig=\"2L\",\n",
" window_size=1_000,\n",
@@ -165,10 +175,11 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f77bbae1",
"id": "e6d69754-951c-48be-9eaf-c3ff2ead80b4",
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"ag3.plot_g123_gwss(\n",
" contig=\"2L\",\n",
" window_size=1_000,\n",
