Add bcftools-style filtering #1330

Draft · wants to merge 2 commits into main
1 change: 1 addition & 0 deletions docs/api.rst
@@ -140,6 +140,7 @@ Utilities
.. autosummary::
:toctree: generated/

bcftools_filter
convert_call_to_index
convert_probability_to_call
display_genotypes
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
"fsspec != 2021.6.*",
"scikit-learn",
"pandas",
"vcztools",
"setuptools >= 41.2", # For pkg_resources
]
dynamic = ["version"]
@@ -69,7 +70,7 @@ fail_under = 100

[tool.pytest.ini_options]
addopts = "--doctest-modules --ignore=validation --cov-fail-under=100"
norecursedirs = [".eggs", "build", "docs"]
norecursedirs = [".eggs", ".hypothesis", "build", "docs"]
filterwarnings = ["error", "ignore::DeprecationWarning"]


1 change: 1 addition & 0 deletions requirements-numpy1.txt
@@ -7,3 +7,4 @@ zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2, < 3
fsspec != 2021.6.*
scikit-learn
pandas
vcztools
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,3 +7,4 @@ zarr >= 2.10.0, != 2.11.0, != 2.11.1, != 2.11.2, < 3
fsspec != 2021.6.*
scikit-learn
pandas
vcztools
2 changes: 2 additions & 0 deletions sgkit/__init__.py
@@ -1,5 +1,6 @@
from .display import display_genotypes, display_pedigree
from .distance.api import pairwise_distance
from .filtering import bcftools_filter
from .io.dataset import load_dataset, save_dataset
from .model import (
DIM_ALLELE,
@@ -78,6 +79,7 @@
"DIM_PLOIDY",
"DIM_SAMPLE",
"DIM_VARIANT",
"bcftools_filter",
"call_allele_frequencies",
"create_genotype_call_dataset",
"cohort_allele_frequencies",
167 changes: 167 additions & 0 deletions sgkit/filtering.py
@@ -0,0 +1,167 @@
from typing import Optional, Union

import numpy as np
import xarray as xr
from vcztools import filter as filter_mod
from vcztools.regions import (
parse_regions,
parse_targets,
regions_to_chunk_indexes,
regions_to_selection,
)
from vcztools.samples import parse_samples


def bcftools_filter(
ds,
*,
regions: Optional[str] = None,
targets: Optional[str] = None,
include: Optional[str] = None,
exclude: Optional[str] = None,
samples: Union[list[str], str, None] = None,
):
"""Filter a dataset using `bcftools`-style expressions.

    The dataset can be subset along the variants dimension to a set
    of genomic regions and/or to sites matching filter expressions
    written in the `bcftools` `expression language <https://samtools.github.io/bcftools/bcftools.html#expressions>`_.
    Additionally, the samples dimension can optionally be subset
    to a list of sample IDs.

Parameters
----------
ds
The dataset to be filtered.
    regions
        The regions to include, specified as a comma-separated list of
        region strings. Corresponds to `bcftools` `-r`/`--regions`.
    targets
        The region targets to include, specified as a comma-separated list
        of region strings. Corresponds to `bcftools` `-t`/`--targets`.
include
Filter expression to include variant sites.
Corresponds to `bcftools` `-i`/`--include`.
exclude
Filter expression to exclude variant sites.
Corresponds to `bcftools` `-e`/`--exclude`.
    samples
        The samples to include, specified either as a comma-separated
        string of sample IDs or as a list of sample IDs.
        Corresponds to `bcftools` `-s`/`--samples`.

Returns
-------
A dataset with a subset of variants and/or samples.
"""
if regions is not None or targets is not None:
ds = _filter_regions(ds, regions, targets)
if include is not None or exclude is not None:
ds = _filter_expressions(ds, include, exclude)
if samples is not None:
ds = _filter_samples(ds, samples)
return ds
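
For reference, a minimal usage sketch of the new function (it mirrors `test_bcftools_filter_all` in the tests below; `sample.vcz` is the store created by the test fixture):

```python
import sgkit as sg

# load a VCF Zarr (VCZ) store previously created with bio2zarr
ds = sg.load_dataset("sample.vcz")

# keep chr20 sites from position 1230236 onwards where some sample has DP > 3,
# restricted to two of the three samples
ds = sg.bcftools_filter(
    ds,
    regions="20:1230236-",
    include="FMT/DP>3",
    samples="NA00002,NA00003",
)
```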


def _filter_regions(ds, regions, targets):
# Use the region index to find the chunks that overlap specified regions or
# targets, and
# 1. convert that into a single coarse slice (with chunk granularity)
# in the variants dimension, then
    # 2. compute a mask over each chunk in the coarse slice to
    #    ensure that exactly the specified regions are selected.
# This works well for small, dense regions but not for sparse regions that
# span the genome.

# Step 1: find smallest single slice that covers regions

contigs = ds["contig_id"].values.astype("U").tolist()
regions_pyranges = parse_regions(regions, contigs)
targets_pyranges, complement = parse_targets(targets, contigs)

region_index = ds["region_index"].values
chunk_indexes = regions_to_chunk_indexes(
regions_pyranges,
targets_pyranges,
complement,
region_index,
)

if len(chunk_indexes) == 0:
# zero variants
return ds.isel(variants=slice(0, 0))

# check chunks are equally sized
chunks = ds.chunks["variants"]
if len(chunks) > 1:
# ignore last chunk since it may be smaller
chunks = chunks[:-1]
if not all(chunk == chunks[0] for chunk in chunks):
raise ValueError(
f'Dataset must have uniform chunk sizes in variants dimension, but was {ds.chunks["variants"]}.'
)
variant_chunksize = chunks[0]

variant_slice = slice(
int(chunk_indexes[0] * variant_chunksize),
        # end of the last overlapping chunk, clamped to the dataset size
        min(ds.sizes["variants"], int((chunk_indexes[-1] + 1) * variant_chunksize)),
)

ds_sliced = ds.isel(variants=variant_slice)
Collaborator Author commented:
This works but could be very inefficient. Imagine we just want one small region at the start of the genome and one at the end; this would read all of the intervening chunks even though they are not needed.

What we really want is a way to index by a set of slices (or even chunks) in one go. NumPy and Xarray don't provide such a primitive, but it's something we could perhaps build.

cc-ing @keewis @TomNicholas and @alxmrs as we were discussing this exact issue in yesterday's Pangeo Distributed Computing meeting (in the context of a geospatial application Justus is working on).
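
A rough sketch of the kind of multi-slice selection helper being discussed; `isel_slices` is hypothetical, not an existing NumPy or Xarray primitive:

```python
import xarray as xr

def isel_slices(ds: xr.Dataset, dim: str, slices: list[slice]) -> xr.Dataset:
    # select each slice separately so only its chunks are read, then
    # stitch the pieces back together along the same dimension;
    # data_vars="minimal" leaves variables without `dim` untouched
    parts = [ds.isel({dim: s}) for s in slices]
    return xr.concat(parts, dim=dim, data_vars="minimal")

# e.g. one small region at each end of the genome, skipping everything between:
# ds_subset = isel_slices(ds, "variants", [slice(0, 10_000), slice(-10_000, None)])
```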

Reply:
I've opened pydata/xarray#10479 to discuss this for xarray.


# Step 2: filter each chunk in sliced dataset

def regions_to_mask(*args):
variant_selection = regions_to_selection(
regions_pyranges, targets_pyranges, complement, *args
)
variant_mask = np.zeros(args[0].shape[0], dtype=bool)
        variant_mask[variant_selection] = True
return variant_mask

# restrict to fields needed by regions_to_selection
ds_filter_fields = ds_sliced[
["variant_contig", "variant_position", "variant_length"]
]
data_arrays = tuple(ds_filter_fields[n] for n in ds_filter_fields.data_vars)
variant_filter = xr.apply_ufunc(
regions_to_mask, *data_arrays, output_dtypes=[bool], dask="parallelized"
)

# note we have to call compute here since xarray indexing with a Dask or Cubed
# boolean array is not supported
return ds_sliced.isel(variants=variant_filter.compute())


def _filter_expressions(ds, include, exclude):
filter_expr = filter_mod.FilterExpression(
field_names=set(ds.data_vars), include=include, exclude=exclude
)

def compute_call_mask(*args):
chunk_data = {k: v for k, v in zip(filter_expr.referenced_fields, args)}
call_mask = filter_expr.evaluate(chunk_data)
return call_mask

# restrict to fields needed by filter expression
ds_filter_fields = ds[list(filter_expr.referenced_fields)]

    # note that this will only work if the dataset is chunked in the
    # variants dimension only; chunks in the samples dimension may need
    # to be merged first (see the sketch after this function)

data_arrays = tuple(ds_filter_fields[n] for n in ds_filter_fields.data_vars)
da = xr.apply_ufunc(
compute_call_mask, *data_arrays, output_dtypes=[bool], dask="parallelized"
)
ds["call_mask"] = da

# filter to variants where at least one sample has been selected
# note we have to call compute here since xarray indexing with a Dask or Cubed
# boolean array is not supported
return ds.isel(variants=ds.call_mask.any(dim="samples").compute())
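
One possible mitigation for the samples-chunking caveat noted above, assuming a Dask- or Cubed-backed dataset (a sketch, not part of this change):

```python
# rechunk so the samples dimension is a single chunk; each block passed to
# apply_ufunc then sees every sample for its variants (-1 means "one chunk")
ds = ds.chunk({"samples": -1})
```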
Collaborator commented:
What are the implications here for the memory requirements when working on a large dataset?

Collaborator Author replied:
Good question. It will materialise a 1D boolean array of length #variants into memory which, unless a restrictive region filter has been applied first, could be very large (potentially one element per variant in the whole genome).

This is an area where there is scope to improve: perhaps by making Xarray and the underlying distributed processing engine able to handle this case efficiently, or by using masked arrays in some way?
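
Back-of-envelope for that mask, using a hypothetical whole-genome variant count:

```python
import numpy as np

n_variants = 1_000_000_000  # hypothetical whole-genome callset
mask_bytes = n_variants * np.dtype(bool).itemsize  # one byte per bool
print(f"{mask_bytes / 1e9:.1f} GB")  # 1.0 GB
```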

Collaborator replied:
I think a 1D boolean array of length #variants is fine; it could be improved, but it's not terrible. I was worried that it was an O(num_samples) thing.



def _filter_samples(ds, samples):
all_samples = ds["sample_id"].values
_, sample_selection = parse_samples(samples, all_samples)
return ds.isel(samples=sample_selection)
78 changes: 78 additions & 0 deletions sgkit/tests/io/test_filtering.py
@@ -0,0 +1,78 @@
import pytest
from numpy.testing import assert_array_equal

pytest.importorskip("bio2zarr")
from bio2zarr import vcf

from sgkit import bcftools_filter, load_dataset


@pytest.fixture()
def vcz(shared_datadir, tmp_path):
vcf_path = shared_datadir / "sample.vcf.gz"
vcz_path = tmp_path.joinpath("sample.vcz").as_posix()

vcf.convert(
[vcf_path],
vcz_path,
variants_chunk_size=5,
samples_chunk_size=2,
)

return vcz_path


def test_bcftools_filter_regions(vcz):
ds = load_dataset(vcz)
ds = bcftools_filter(ds, regions="20:1230236-")

assert ds.sizes["variants"] == 3
assert ds.sizes["samples"] == 3
assert_array_equal(ds["variant_position"], [1230237, 1234567, 1235237])


def test_bcftools_filter_empty_regions(vcz):
ds = load_dataset(vcz)
ds = bcftools_filter(ds, regions="20:1-2")

assert ds.sizes["variants"] == 0
assert ds.sizes["samples"] == 3
assert len(ds["variant_position"]) == 0


def test_bcftools_filter_expressions(vcz):
ds = load_dataset(vcz)
ds = bcftools_filter(ds, include="FMT/DP>3")

assert ds.sizes["variants"] == 5
assert ds.sizes["samples"] == 3
assert_array_equal(ds["variant_contig"], [1, 1, 1, 1, 1])
assert_array_equal(
ds["variant_position"], [14370, 17330, 1110696, 1230237, 1234567]
)


def test_bcftools_filter_samples(vcz):
ds = load_dataset(vcz)
ds = bcftools_filter(ds, samples="NA00002,NA00003")

assert ds.sizes["variants"] == 9
assert ds.sizes["samples"] == 2
assert_array_equal(ds["sample_id"], ["NA00002", "NA00003"])


def test_bcftools_filter_all(vcz):
ds = load_dataset(vcz)
assert ds.sizes["variants"] == 9
assert ds.sizes["samples"] == 3

ds = bcftools_filter(
ds, regions="20:1230236-", include="FMT/DP>3", samples="NA00002,NA00003"
)

assert ds.sizes["variants"] == 2
assert ds.sizes["samples"] == 2

assert_array_equal(ds["variant_contig"], [1, 1])
assert_array_equal(ds["variant_position"], [1230237, 1234567])
assert_array_equal(ds["sample_id"], ["NA00002", "NA00003"])