diff --git a/malariagen_data/ag3.py b/malariagen_data/ag3.py index b451be7fd..fd2edd561 100644 --- a/malariagen_data/ag3.py +++ b/malariagen_data/ag3.py @@ -20,7 +20,7 @@ Region, da_from_zarr, init_zarr_store, - xarray_concat, + simple_xarray_concat, ) # silence dask performance warnings @@ -463,7 +463,7 @@ def _snp_calls_for_contig(self, contig, *, sample_set, inline_array, chunks): chunks=chunks, ) - ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT) + ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT) return ds @@ -493,7 +493,7 @@ def _snp_variants_for_contig(self, contig, *, inline_array, chunks): max_r = super().genome_sequence(region=contig_r).shape[0] ds_l["variant_position"] = ds_l["variant_position"] + max_r - ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT) + ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT) return ds @@ -566,7 +566,7 @@ def _haplotypes_for_contig( if ds_l is not None: max_r = super().genome_sequence(region=contig_r).shape[0] ds_l["variant_position"] = ds_l["variant_position"] + max_r - ds = xr.concat([ds_r, ds_l], dim=DIM_VARIANT) + ds = simple_xarray_concat([ds_r, ds_l], dim=DIM_VARIANT) return ds return None @@ -735,7 +735,7 @@ def cnv_hmm( ly.append(y) debug("concatenate data from multiple sample sets") - x = xarray_concat(ly, dim=DIM_SAMPLE) + x = simple_xarray_concat(ly, dim=DIM_SAMPLE) debug("handle region, do this only once - optimisation") if r.start is not None or r.end is not None: @@ -750,7 +750,7 @@ def cnv_hmm( lx.append(x) debug("concatenate data from multiple regions") - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) debug("handle sample query") if sample_query is not None: @@ -964,7 +964,7 @@ def cnv_coverage_calls( x = x.isel(variants=loc_region) lx.append(x) - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) return ds @@ -1122,10 +1122,10 @@ def cnv_discordant_read_calls( ) ly.append(y) - x = xarray_concat(ly, dim=DIM_SAMPLE) + x = 
simple_xarray_concat(ly, dim=DIM_SAMPLE) lx.append(x) - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) return ds @@ -1167,7 +1167,7 @@ def gene_cnv( if isinstance(region, Region): region = [region] - ds = xarray_concat( + ds = simple_xarray_concat( [ self._gene_cnv( region=r, @@ -1557,7 +1557,7 @@ def gene_cnv_frequencies_advanced( if isinstance(region, Region): region = [region] - ds = xarray_concat( + ds = simple_xarray_concat( [ self._gene_cnv_frequencies_advanced( region=r, @@ -2327,7 +2327,7 @@ def aim_calls( ly.append(y) debug("concatenate data from multiple sample sets") - ds = xarray_concat(ly, dim=DIM_SAMPLE) + ds = simple_xarray_concat(ly, dim=DIM_SAMPLE) debug("handle sample query") if sample_query is not None: diff --git a/malariagen_data/amin1.py b/malariagen_data/amin1.py index 3633b584c..53bd37ecb 100644 --- a/malariagen_data/amin1.py +++ b/malariagen_data/amin1.py @@ -16,8 +16,8 @@ locate_region, read_gff3, resolve_region, + simple_xarray_concat, unpack_gff3_attributes, - xarray_concat, ) GENOME_FEATURES_GFF3_PATH = ( @@ -265,13 +265,9 @@ def snp_calls(self, region, site_mask=False, inline_array=True, chunks="native") ) for r in region ] - ds = xarray_concat( + ds = simple_xarray_concat( datasets, dim=DIM_VARIANT, - data_vars="minimal", - coords="minimal", - compat="override", - join="override", ) # apply site filters diff --git a/malariagen_data/anopheles.py b/malariagen_data/anopheles.py index ec939d135..e3ba4abcf 100644 --- a/malariagen_data/anopheles.py +++ b/malariagen_data/anopheles.py @@ -47,8 +47,8 @@ locate_region, plotly_discrete_legend, resolve_region, + simple_xarray_concat, type_error, - xarray_concat, ) AA_CHANGE_QUERY = ( @@ -1956,7 +1956,7 @@ def snp_calls( ly.append(y) debug("concatenate data from multiple sample sets") - x = xarray_concat(ly, dim=DIM_SAMPLE) + x = simple_xarray_concat(ly, dim=DIM_SAMPLE) debug("add variants variables") v = self._snp_variants_for_contig( @@ -1982,7 +1982,7 
@@ def snp_calls( lx.append(x) debug("concatenate data from multiple regions") - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) if site_mask is not None: debug("apply site filters") @@ -2114,7 +2114,7 @@ def snp_variants( lx.append(x) debug("concatenate data from multiple regions") - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) debug("apply site filters") if site_mask is not None: @@ -5285,7 +5285,7 @@ def haplotypes( return None debug("concatenate data from multiple sample sets") - x = xarray_concat(ly, dim=DIM_SAMPLE) + x = simple_xarray_concat(ly, dim=DIM_SAMPLE) debug("handle region") if r.start or r.end: @@ -5296,7 +5296,7 @@ def haplotypes( lx.append(x) debug("concatenate data from multiple regions") - ds = xarray_concat(lx, dim=DIM_VARIANT) + ds = simple_xarray_concat(lx, dim=DIM_VARIANT) debug("handle sample query") if sample_query is not None: diff --git a/malariagen_data/util.py b/malariagen_data/util.py index 9596c25d6..cec144ca8 100644 --- a/malariagen_data/util.py +++ b/malariagen_data/util.py @@ -4,10 +4,9 @@ import re import sys import warnings -from collections.abc import Mapping from enum import Enum from textwrap import dedent, fill -from typing import IO, Optional, Tuple, Union +from typing import IO, Dict, Hashable, List, Mapping, Optional, Tuple, Union from urllib.parse import unquote_plus try: @@ -471,27 +470,107 @@ def locate_region(region, pos): return loc_region -def xarray_concat( - datasets, - dim, - data_vars="minimal", - coords="minimal", - compat="override", - join="override", - **kwargs, -): +def _simple_xarray_concat_arrays( + datasets: List[xr.Dataset], names: List[Hashable], dim: str +) -> Mapping[Hashable, xr.DataArray]: + # Access the first dataset, this will be used as the template for + # any arrays that don't need to be concatenated. + ds0 = datasets[0] + + # Set up return value, collection of concatenated arrays. 
+ out: Dict[Hashable, xr.DataArray] = dict() + + # Iterate over variable names. + for k in names: + # Access the variable from the first dataset. + v = ds0[k] + + if dim in v.dims: + # Dimension to be concatenated is present, need to concatenate. + + # Figure out which axis corresponds to the given dimension. + axis = v.dims.index(dim) + + # Access the xarray DataArrays to be concatenated. + xr_arrays = [ds[k] for ds in datasets] + + # Check that all arrays have the same dimension at the same axis. + assert all([a.dims[axis] == dim for a in xr_arrays]) + + # Access the inner arrays - these are either numpy or dask arrays. + inner_arrays = [a.data for a in xr_arrays] + + # Concatenate inner arrays, depending on their type. + if isinstance(inner_arrays[0], da.Array): + concatenated_array = da.concatenate(inner_arrays, axis=axis) + else: + concatenated_array = np.concatenate(inner_arrays, axis=axis) + + # Store the result. + out[k] = xr.DataArray(data=concatenated_array, dims=v.dims) + + else: + # No concatenation is needed, keep the variable from the first dataset. + out[k] = v + + return out + + +def simple_xarray_concat( + datasets: List[xr.Dataset], dim: str, attrs: Optional[Mapping] = None +) -> xr.Dataset: + # Access the first dataset, this will be used as the template for + # any arrays that don't need to be concatenated. + ds0 = datasets[0] + + if attrs is None: + # Copy attributes from the first dataset. + attrs = ds0.attrs + if len(datasets) == 1: - return datasets[0] - else: - return xr.concat( - datasets, - dim=dim, - data_vars=data_vars, - coords=coords, - compat=compat, - join=join, - **kwargs, - ) + # Fast path, nothing to concatenate. + return ds0 + + # Concatenate coordinate variables. + coords = _simple_xarray_concat_arrays( + datasets=datasets, + names=list(ds0.coords), + dim=dim, + ) + + # Concatenate data variables. 
+ data_vars = _simple_xarray_concat_arrays( + datasets=datasets, + names=list(ds0.data_vars), + dim=dim, + ) + + return xr.Dataset(coords=coords, data_vars=data_vars, attrs=attrs) + + +# xarray concat() function is very slow, don't use for now +# +# def xarray_concat( +# datasets, +# dim, +# data_vars="minimal", +# coords="minimal", +# compat="override", +# join="override", +# **kwargs, +# ): +# if len(datasets) == 1: +# return datasets[0] +# else: +# return xr.concat( +# datasets, +# dim=dim, +# data_vars=data_vars, +# coords=coords, +# compat=compat, +# join=join, +# **kwargs, +# ) def type_error(