diff --git a/docs/api.rst b/docs/api.rst
index 512a3ccda..c013a1c0a 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -140,6 +140,7 @@ Utilities
 .. autosummary::
    :toctree: generated/
 
+   bcftools_filter
    convert_call_to_index
    convert_probability_to_call
    display_genotypes
diff --git a/pyproject.toml b/pyproject.toml
index 7ef0d03c4..a592d94d4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     "fsspec != 2021.6.*",
     "scikit-learn",
     "pandas",
+    "vcztools",
     "setuptools >= 41.2", # For pkg_resources
 ]
 dynamic = ["version"]
@@ -69,7 +70,7 @@ fail_under = 100
 
 [tool.pytest.ini_options]
 addopts = "--doctest-modules --ignore=validation --cov-fail-under=100"
-norecursedirs = [".eggs", "build", "docs"]
+norecursedirs = [".eggs", ".hypothesis", "build", "docs"]
 filterwarnings = ["error", "ignore::DeprecationWarning"]
diff --git a/requirements-numpy1.txt b/requirements-numpy1.txt
index b4da65449..c5e48e8c2 100644
--- a/requirements-numpy1.txt
+++ b/requirements-numpy1.txt
@@ -7,3 +7,4 @@ zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2, < 3
 fsspec != 2021.6.*
 scikit-learn
 pandas
+vcztools
diff --git a/requirements.txt b/requirements.txt
index 8b130d40d..c1dfe708c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,3 +7,4 @@ zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2, < 3
 fsspec != 2021.6.*
 scikit-learn
 pandas
+vcztools
diff --git a/sgkit/__init__.py b/sgkit/__init__.py
index d5bf96a37..ef8eae117 100644
--- a/sgkit/__init__.py
+++ b/sgkit/__init__.py
@@ -1,5 +1,6 @@
 from .display import display_genotypes, display_pedigree
 from .distance.api import pairwise_distance
+from .filtering import bcftools_filter
 from .io.dataset import load_dataset, save_dataset
 from .model import (
     DIM_ALLELE,
@@ -78,6 +79,7 @@
     "DIM_PLOIDY",
     "DIM_SAMPLE",
     "DIM_VARIANT",
+    "bcftools_filter",
     "call_allele_frequencies",
     "create_genotype_call_dataset",
     "cohort_allele_frequencies",
diff --git a/sgkit/filtering.py b/sgkit/filtering.py
new file mode 100644
index 000000000..d42d8adae
--- /dev/null
+++ b/sgkit/filtering.py
@@ -0,0 +1,167 @@
+from typing import Optional, Union
+
+import numpy as np
+import xarray as xr
+from vcztools import filter as filter_mod
+from vcztools.regions import (
+    parse_regions,
+    parse_targets,
+    regions_to_chunk_indexes,
+    regions_to_selection,
+)
+from vcztools.samples import parse_samples
+
+
+def bcftools_filter(
+    ds,
+    *,
+    regions: Optional[str] = None,
+    targets: Optional[str] = None,
+    include: Optional[str] = None,
+    exclude: Optional[str] = None,
+    samples: Union[list[str], str, None] = None,
+):
+    """Filter a dataset using `bcftools`-style expressions.
+
+    The dataset can be subset in the variants dimension to a set
+    of genomic regions and/or to the sites matched by filter
+    expressions written in the `bcftools` `expression language
+    <https://samtools.github.io/bcftools/bcftools.html#expressions>`_.
+    Additionally, the samples dimension can optionally be subset
+    to a list of sample IDs.
+
+    Parameters
+    ----------
+    ds
+        The dataset to be filtered.
+    regions
+        The regions to include, specified as a comma-separated list of
+        region strings. Corresponds to `bcftools` `-r`/`--regions`.
+    targets
+        The region targets to include, specified as a comma-separated
+        list of region strings. Corresponds to `bcftools` `-t`/`--targets`.
+    include
+        Filter expression to include variant sites.
+        Corresponds to `bcftools` `-i`/`--include`.
+    exclude
+        Filter expression to exclude variant sites.
+        Corresponds to `bcftools` `-e`/`--exclude`.
+    samples
+        The samples to include, specified as a comma-separated list of
+        sample IDs, or a list of sample IDs.
+        Corresponds to `bcftools` `-s`/`--samples`.
+
+    Returns
+    -------
+    A dataset with a subset of variants and/or samples.
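+
+    Examples
+    --------
+    A minimal sketch of combined usage; the store path here is
+    illustrative (the region, expression, and sample IDs mirror the
+    tests for this module):
+
+    .. code-block:: python
+
+        import sgkit as sg
+
+        ds = sg.load_dataset("sample.vcz")
+        # keep contig 20 sites from position 1230236 on, where at least
+        # one sample has FORMAT/DP > 3, restricted to two samples
+        ds = sg.bcftools_filter(
+            ds,
+            regions="20:1230236-",
+            include="FMT/DP>3",
+            samples="NA00002,NA00003",
+        )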
+    """
+    if regions is not None or targets is not None:
+        ds = _filter_regions(ds, regions, targets)
+    if include is not None or exclude is not None:
+        ds = _filter_expressions(ds, include, exclude)
+    if samples is not None:
+        ds = _filter_samples(ds, samples)
+    return ds
+
+
+def _filter_regions(ds, regions, targets):
+    # Use the region index to find the chunks that overlap the specified
+    # regions or targets, and
+    # 1. convert that into a single coarse slice (with chunk granularity)
+    #    in the variants dimension, then
+    # 2. compute a mask for each chunk in the coarse slice to ensure that
+    #    exactly the requested regions are selected.
+    # This works well for small, dense regions, but not for sparse regions
+    # that span the genome.
+
+    # Step 1: find the smallest single slice that covers the regions
+
+    contigs = ds["contig_id"].values.astype("U").tolist()
+    regions_pyranges = parse_regions(regions, contigs)
+    targets_pyranges, complement = parse_targets(targets, contigs)
+
+    region_index = ds["region_index"].values
+    chunk_indexes = regions_to_chunk_indexes(
+        regions_pyranges,
+        targets_pyranges,
+        complement,
+        region_index,
+    )
+
+    if len(chunk_indexes) == 0:
+        # zero variants
+        return ds.isel(variants=slice(0, 0))
+
+    # check that chunks are equally sized
+    chunks = ds.chunks["variants"]
+    if len(chunks) > 1:
+        # ignore the last chunk since it may be smaller
+        chunks = chunks[:-1]
+    if not all(chunk == chunks[0] for chunk in chunks):
+        raise ValueError(
+            f'Dataset must have uniform chunk sizes in variants dimension, but was {ds.chunks["variants"]}.'
+        )
+    variant_chunksize = chunks[0]
+
+    variant_slice = slice(
+        int(chunk_indexes[0] * variant_chunksize),
+        min(ds.sizes["variants"], int((chunk_indexes[-1] + 1) * variant_chunksize)),
+    )
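+    # For example, with variant_chunksize=5 and chunk_indexes=[1, 3], the
+    # coarse slice is slice(5, 20) (capped at the number of variants),
+    # covering chunks 1-3; step 2 below masks out any variants in chunk 2
+    # that fall outside the requested regions.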
+
+    ds_sliced = ds.isel(variants=variant_slice)
+
+    # Step 2: filter each chunk in the sliced dataset
+
+    def regions_to_mask(*args):
+        variant_selection = regions_to_selection(
+            regions_pyranges, targets_pyranges, complement, *args
+        )
+        variant_mask = np.zeros(args[0].shape[0], dtype=bool)
+        variant_mask[variant_selection] = True
+        return variant_mask
+
+    # restrict to the fields needed by regions_to_selection
+    ds_filter_fields = ds_sliced[
+        ["variant_contig", "variant_position", "variant_length"]
+    ]
+    data_arrays = tuple(ds_filter_fields[n] for n in ds_filter_fields.data_vars)
+    variant_filter = xr.apply_ufunc(
+        regions_to_mask, *data_arrays, output_dtypes=[bool], dask="parallelized"
+    )
+
+    # note we have to call compute here since xarray indexing with a Dask or
+    # Cubed boolean array is not supported
+    return ds_sliced.isel(variants=variant_filter.compute())
+
+
+def _filter_expressions(ds, include, exclude):
+    filter_expr = filter_mod.FilterExpression(
+        field_names=set(ds.data_vars), include=include, exclude=exclude
+    )
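+    # For example, include="FMT/DP>3" references the FORMAT field DP, which
+    # is stored as call_DP in the dataset, so referenced_fields would be
+    # {"call_DP"} and evaluate would return a boolean call mask of shape
+    # (variants, samples) for each chunk.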
+
+    def compute_call_mask(*args):
+        chunk_data = {k: v for k, v in zip(filter_expr.referenced_fields, args)}
+        call_mask = filter_expr.evaluate(chunk_data)
+        return call_mask
+
+    # restrict to the fields needed by the filter expression
+    ds_filter_fields = ds[list(filter_expr.referenced_fields)]
+
+    # note that this will only work if the dataset is chunked only in the
+    # variants dimension; we may need to merge chunks in the samples dimension
+
+    data_arrays = tuple(ds_filter_fields[n] for n in ds_filter_fields.data_vars)
+    da = xr.apply_ufunc(
+        compute_call_mask, *data_arrays, output_dtypes=[bool], dask="parallelized"
+    )
+    ds["call_mask"] = da
+
+    # filter to variants where at least one sample has been selected
+    # note we have to call compute here since xarray indexing with a Dask or
+    # Cubed boolean array is not supported
+    return ds.isel(variants=ds.call_mask.any(dim="samples").compute())
+
+
+def _filter_samples(ds, samples):
+    all_samples = ds["sample_id"].values
+    _, sample_selection = parse_samples(samples, all_samples)
+    return ds.isel(samples=sample_selection)
diff --git a/sgkit/tests/io/test_filtering.py b/sgkit/tests/io/test_filtering.py
new file mode 100644
index 000000000..e9319c998
--- /dev/null
+++ b/sgkit/tests/io/test_filtering.py
@@ -0,0 +1,78 @@
+import pytest
+from numpy.testing import assert_array_equal
+
+pytest.importorskip("bio2zarr")
+from bio2zarr import vcf
+
+from sgkit import bcftools_filter, load_dataset
+
+
+@pytest.fixture()
+def vcz(shared_datadir, tmp_path):
+    vcf_path = shared_datadir / "sample.vcf.gz"
+    vcz_path = tmp_path.joinpath("sample.vcz").as_posix()
+
+    vcf.convert(
+        [vcf_path],
+        vcz_path,
+        variants_chunk_size=5,
+        samples_chunk_size=2,
+    )
+
+    return vcz_path
+
+
+def test_bcftools_filter_regions(vcz):
+    ds = load_dataset(vcz)
+    ds = bcftools_filter(ds, regions="20:1230236-")
+
+    assert ds.sizes["variants"] == 3
+    assert ds.sizes["samples"] == 3
+    assert_array_equal(ds["variant_position"], [1230237, 1234567, 1235237])
+
+
+def test_bcftools_filter_empty_regions(vcz):
+    ds = load_dataset(vcz)
+    ds = bcftools_filter(ds, regions="20:1-2")
+
+    assert ds.sizes["variants"] == 0
+    assert ds.sizes["samples"] == 3
+    assert len(ds["variant_position"]) == 0
+
+
+def test_bcftools_filter_expressions(vcz):
+    ds = load_dataset(vcz)
+    ds = bcftools_filter(ds, include="FMT/DP>3")
+
+    assert ds.sizes["variants"] == 5
+    assert ds.sizes["samples"] == 3
+    assert_array_equal(ds["variant_contig"], [1, 1, 1, 1, 1])
+    assert_array_equal(
+        ds["variant_position"], [14370, 17330, 1110696, 1230237, 1234567]
+    )
+
+
+def test_bcftools_filter_samples(vcz):
+    ds = load_dataset(vcz)
+    ds = bcftools_filter(ds, samples="NA00002,NA00003")
+
+    assert ds.sizes["variants"] == 9
+    assert ds.sizes["samples"] == 2
+    assert_array_equal(ds["sample_id"], ["NA00002", "NA00003"])
+
+
+def test_bcftools_filter_all(vcz):
+    ds = load_dataset(vcz)
+    assert ds.sizes["variants"] == 9
+    assert ds.sizes["samples"] == 3
+
+    ds = bcftools_filter(
+        ds, regions="20:1230236-", include="FMT/DP>3", samples="NA00002,NA00003"
+    )
+
+    assert ds.sizes["variants"] == 2
+    assert ds.sizes["samples"] == 2
+
+    assert_array_equal(ds["variant_contig"], [1, 1])
+    assert_array_equal(ds["variant_position"], [1230237, 1234567])
+    assert_array_equal(ds["sample_id"], ["NA00002", "NA00003"])