Skip to content

Commit

Permalink
Faster xarray concatenation (#395)
Browse files Browse the repository at this point in the history
* use faster xarray concatenation

* override attrs

* handle either dask or numpy arrays

* fix typing

* fix import
  • Loading branch information
alimanfoo authored May 11, 2023
1 parent 29b7023 commit 61bb489
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 46 deletions.
24 changes: 12 additions & 12 deletions malariagen_data/ag3.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Region,
da_from_zarr,
init_zarr_store,
xarray_concat,
simple_xarray_concat,
)

# silence dask performance warnings
Expand Down Expand Up @@ -463,7 +463,7 @@ def _snp_calls_for_contig(self, contig, *, sample_set, inline_array, chunks):
chunks=chunks,
)

ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT)
ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT)

return ds

Expand Down Expand Up @@ -493,7 +493,7 @@ def _snp_variants_for_contig(self, contig, *, inline_array, chunks):
max_r = super().genome_sequence(region=contig_r).shape[0]
ds_l["variant_position"] = ds_l["variant_position"] + max_r

ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT)
ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT)

return ds

Expand Down Expand Up @@ -566,7 +566,7 @@ def _haplotypes_for_contig(
if ds_l is not None:
max_r = super().genome_sequence(region=contig_r).shape[0]
ds_l["variant_position"] = ds_l["variant_position"] + max_r
ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT)
ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT)
return ds

return None
Expand Down Expand Up @@ -735,7 +735,7 @@ def cnv_hmm(
ly.append(y)

debug("concatenate data from multiple sample sets")
x = xarray_concat(ly, dim=DIM_SAMPLE)
x = simple_xarray_concat(ly, dim=DIM_SAMPLE)

debug("handle region, do this only once - optimisation")
if r.start is not None or r.end is not None:
Expand All @@ -750,7 +750,7 @@ def cnv_hmm(
lx.append(x)

debug("concatenate data from multiple regions")
ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

debug("handle sample query")
if sample_query is not None:
Expand Down Expand Up @@ -964,7 +964,7 @@ def cnv_coverage_calls(
x = x.isel(variants=loc_region)

lx.append(x)
ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

return ds

Expand Down Expand Up @@ -1122,10 +1122,10 @@ def cnv_discordant_read_calls(
)
ly.append(y)

x = xarray_concat(ly, dim=DIM_SAMPLE)
x = simple_xarray_concat(ly, dim=DIM_SAMPLE)
lx.append(x)

ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

return ds

Expand Down Expand Up @@ -1167,7 +1167,7 @@ def gene_cnv(
if isinstance(region, Region):
region = [region]

ds = xarray_concat(
ds = simple_xarray_concat(
[
self._gene_cnv(
region=r,
Expand Down Expand Up @@ -1557,7 +1557,7 @@ def gene_cnv_frequencies_advanced(
if isinstance(region, Region):
region = [region]

ds = xarray_concat(
ds = simple_xarray_concat(
[
self._gene_cnv_frequencies_advanced(
region=r,
Expand Down Expand Up @@ -2327,7 +2327,7 @@ def aim_calls(
ly.append(y)

debug("concatenate data from multiple sample sets")
ds = xarray_concat(ly, dim=DIM_SAMPLE)
ds = simple_xarray_concat(ly, dim=DIM_SAMPLE)

debug("handle sample query")
if sample_query is not None:
Expand Down
8 changes: 2 additions & 6 deletions malariagen_data/amin1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
locate_region,
read_gff3,
resolve_region,
simple_xarray_concat,
unpack_gff3_attributes,
xarray_concat,
)

GENOME_FEATURES_GFF3_PATH = (
Expand Down Expand Up @@ -265,13 +265,9 @@ def snp_calls(self, region, site_mask=False, inline_array=True, chunks="native")
)
for r in region
]
ds = xarray_concat(
ds = simple_xarray_concat(
datasets,
dim=DIM_VARIANT,
data_vars="minimal",
coords="minimal",
compat="override",
join="override",
)

# apply site filters
Expand Down
12 changes: 6 additions & 6 deletions malariagen_data/anopheles.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@
locate_region,
plotly_discrete_legend,
resolve_region,
simple_xarray_concat,
type_error,
xarray_concat,
)

AA_CHANGE_QUERY = (
Expand Down Expand Up @@ -1956,7 +1956,7 @@ def snp_calls(
ly.append(y)

debug("concatenate data from multiple sample sets")
x = xarray_concat(ly, dim=DIM_SAMPLE)
x = simple_xarray_concat(ly, dim=DIM_SAMPLE)

debug("add variants variables")
v = self._snp_variants_for_contig(
Expand All @@ -1982,7 +1982,7 @@ def snp_calls(
lx.append(x)

debug("concatenate data from multiple regions")
ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

if site_mask is not None:
debug("apply site filters")
Expand Down Expand Up @@ -2114,7 +2114,7 @@ def snp_variants(
lx.append(x)

debug("concatenate data from multiple regions")
ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

debug("apply site filters")
if site_mask is not None:
Expand Down Expand Up @@ -5285,7 +5285,7 @@ def haplotypes(
return None

debug("concatenate data from multiple sample sets")
x = xarray_concat(ly, dim=DIM_SAMPLE)
x = simple_xarray_concat(ly, dim=DIM_SAMPLE)

debug("handle region")
if r.start or r.end:
Expand All @@ -5296,7 +5296,7 @@ def haplotypes(
lx.append(x)

debug("concatenate data from multiple regions")
ds = xarray_concat(lx, dim=DIM_VARIANT)
ds = simple_xarray_concat(lx, dim=DIM_VARIANT)

debug("handle sample query")
if sample_query is not None:
Expand Down
123 changes: 101 additions & 22 deletions malariagen_data/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import re
import sys
import warnings
from collections.abc import Mapping
from enum import Enum
from textwrap import dedent, fill
from typing import IO, Optional, Tuple, Union
from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union
from urllib.parse import unquote_plus

try:
Expand Down Expand Up @@ -471,27 +470,107 @@ def locate_region(region, pos):
return loc_region


def xarray_concat(
datasets,
dim,
data_vars="minimal",
coords="minimal",
compat="override",
join="override",
**kwargs,
):
def _simple_xarray_concat_arrays(
    datasets: List[xr.Dataset], names: List[Hashable], dim: str
) -> Mapping[Hashable, xr.DataArray]:
    """Concatenate the named variables of several datasets along `dim`.

    Each variable is looked up by name in the first dataset. If it has the
    `dim` dimension, the underlying arrays from every dataset are joined
    along the matching axis; otherwise the first dataset's variable is
    returned unchanged.

    Parameters
    ----------
    datasets
        Datasets to combine. The first one is the template that decides
        which variables are concatenated and on which axis.
    names
        Names of the variables to process.
    dim
        Name of the dimension to concatenate along.

    Returns
    -------
    Mapping from variable name to the resulting DataArray. Concatenated
    arrays are rebuilt from data and dimension names only.

    Raises
    ------
    ValueError
        If a variable does not have `dim` at the same axis position in
        every dataset.
    """
    # The first dataset is the template for any arrays that don't need to
    # be concatenated, and fixes the axis for those that do.
    ds0 = datasets[0]

    # Collection of output arrays, keyed by variable name.
    out: Dict[Hashable, xr.DataArray] = dict()

    for k in names:
        # Access the variable from the first dataset.
        v = ds0[k]

        if dim not in v.dims:
            # No concatenation is needed, keep the variable from the first
            # dataset.
            out[k] = v
            continue

        # Figure out which axis corresponds to the given dimension.
        axis = v.dims.index(dim)

        # Access the xarray DataArrays to be concatenated.
        xr_arrays = [ds[k] for ds in datasets]

        # Check that all arrays have the dimension at the same axis. This is
        # raised explicitly rather than asserted so the check still runs
        # under `python -O`, and so an array with too few dimensions is
        # reported clearly instead of as an IndexError.
        for i, a in enumerate(xr_arrays):
            if len(a.dims) <= axis or a.dims[axis] != dim:
                raise ValueError(
                    f"Cannot concatenate variable {k!r}: dataset {i} has "
                    f"dimensions {a.dims!r}, expected {dim!r} at axis {axis}."
                )

        # Access the inner arrays - these are either numpy or dask arrays.
        inner_arrays = [a.data for a in xr_arrays]

        # Concatenate inner arrays, depending on the type of the first one.
        if isinstance(inner_arrays[0], da.Array):
            concatenated_array = da.concatenate(inner_arrays, axis=axis)
        else:
            concatenated_array = np.concatenate(inner_arrays, axis=axis)

        # Store the result, reusing the dimension names of the template.
        out[k] = xr.DataArray(data=concatenated_array, dims=v.dims)

    return out


def simple_xarray_concat(
    datasets: List[xr.Dataset], dim: str, attrs: Optional[Mapping] = None
) -> xr.Dataset:
    """Concatenate datasets along `dim` without calling `xr.concat`.

    The first dataset is the template: its coordinate and data variable
    names decide what appears in the output, and any variable that does not
    have the `dim` dimension is taken from it unchanged.

    Parameters
    ----------
    datasets
        Datasets to concatenate. Must contain at least one dataset.
    dim
        Name of the dimension to concatenate along.
    attrs
        Attributes for the returned dataset. If None, the attributes of the
        first dataset are used.

    Returns
    -------
    The concatenated dataset. When a single dataset is given and `attrs` is
    None, that dataset object itself is returned.

    Raises
    ------
    ValueError
        If `datasets` is empty.
    """
    if not datasets:
        # Fail with a clear message instead of an IndexError on datasets[0].
        raise ValueError("must supply at least one dataset to concatenate")

    # Access the first dataset, this will be used as the template for
    # any arrays that don't need to be concatenated.
    ds0 = datasets[0]

    if len(datasets) == 1:
        # Fast path, nothing to concatenate.
        if attrs is None:
            return ds0
        # Honour an explicit attrs override here as well. A shallow copy is
        # used so the caller's dataset is not mutated.
        single = ds0.copy()
        single.attrs = dict(attrs)
        return single

    if attrs is None:
        # Copy attributes from the first dataset.
        attrs = ds0.attrs

    # Concatenate coordinate variables.
    coords = _simple_xarray_concat_arrays(
        datasets=datasets,
        names=list(ds0.coords),
        dim=dim,
    )

    # Concatenate data variables.
    data_vars = _simple_xarray_concat_arrays(
        datasets=datasets,
        names=list(ds0.data_vars),
        dim=dim,
    )

    return xr.Dataset(coords=coords, data_vars=data_vars, attrs=attrs)


# xarray concat() function is very slow, don't use for now
#
# def xarray_concat(
# datasets,
# dim,
# data_vars="minimal",
# coords="minimal",
# compat="override",
# join="override",
# **kwargs,
# ):
# if len(datasets) == 1:
# return datasets[0]
# else:
# return xr.concat(
# datasets,
# dim=dim,
# data_vars=data_vars,
# coords=coords,
# compat=compat,
# join=join,
# **kwargs,
# )


def type_error(
Expand Down

0 comments on commit 61bb489

Please sign in to comment.