Skip to content

Commit

Permalink
Merge pull request #3528 from broadinstitute/hail-backend-filter-loca…
Browse files Browse the repository at this point in the history
…tion

Hail backend filter location
  • Loading branch information
hanars authored Aug 3, 2023
2 parents 521cb7a + 7d9730e commit 6483883
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 35 deletions.
95 changes: 84 additions & 11 deletions hail_search/hail_search_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def _format_population_config(cls, pop_config):
base_pop_config.update(pop_config)
return base_pop_config

@property
def annotation_fields(self):
ht_globals = {k: hl.eval(self._ht[k]) for k in self.GLOBALS}
enums = ht_globals.pop('enums')
Expand Down Expand Up @@ -143,11 +142,19 @@ def __init__(self, data_type, sample_data, genome_version, sort=XPOS, num_result

self._load_filtered_table(data_type, sample_data, **kwargs)

def _load_filtered_table(self, data_type, sample_data, **kwargs):
self.import_filtered_table(data_type, sample_data, **kwargs)

def import_filtered_table(self, data_type, sample_data, **kwargs):
def _load_filtered_table(self, data_type, sample_data, intervals=None, exclude_intervals=False, variant_ids=None, **kwargs):
parsed_intervals, variant_ids = self._parse_intervals(intervals, variant_ids)
excluded_intervals = None
if exclude_intervals:
excluded_intervals = parsed_intervals
parsed_intervals = None
self.import_filtered_table(
data_type, sample_data, intervals=parsed_intervals, excluded_intervals=excluded_intervals,
variant_ids=variant_ids, **kwargs)

def import_filtered_table(self, data_type, sample_data, intervals=None, **kwargs):
tables_path = f'{DATASETS_DIR}/{self._genome_version}/{data_type}'
load_table_kwargs = {'_intervals': intervals, '_filter_intervals': bool(intervals)}

family_samples = defaultdict(list)
project_samples = defaultdict(list)
Expand All @@ -158,13 +165,13 @@ def import_filtered_table(self, data_type, sample_data, **kwargs):
logger.info(f'Loading {data_type} data for {len(family_samples)} families in {len(project_samples)} projects')
if len(family_samples) == 1:
family_guid, family_sample_data = list(family_samples.items())[0]
family_ht = hl.read_table(f'{tables_path}/families/{family_guid}.ht')
family_ht = hl.read_table(f'{tables_path}/families/{family_guid}.ht', **load_table_kwargs)
families_ht = self._filter_entries_table(family_ht, family_sample_data, **kwargs)
else:
filtered_project_hts = []
exception_messages = set()
for project_guid, project_sample_data in project_samples.items():
project_ht = hl.read_table(f'{tables_path}/projects/{project_guid}.ht')
project_ht = hl.read_table(f'{tables_path}/projects/{project_guid}.ht', **load_table_kwargs)
try:
filtered_project_hts.append(self._filter_entries_table(project_ht, project_sample_data, **kwargs))
except HTTPBadRequest as e:
Expand Down Expand Up @@ -204,8 +211,14 @@ def import_filtered_table(self, data_type, sample_data, **kwargs):
)
self._filter_annotated_table(**kwargs)

def _filter_entries_table(self, ht, sample_data, inheritance_mode=None, inheritance_filter=None, quality_filter=None,
**kwargs):
def _filter_entries_table(self, ht, sample_data, inheritance_mode=None, inheritance_filter=None, quality_filter=None,
excluded_intervals=None, variant_ids=None, **kwargs):
if excluded_intervals:
ht = hl.filter_intervals(ht, excluded_intervals, keep=False)

if variant_ids:
ht = self._filter_variant_ids(ht, variant_ids)

ht, sample_id_family_index_map = self._add_entry_sample_families(ht, sample_data)

ht = self._filter_inheritance(
Expand Down Expand Up @@ -375,9 +388,69 @@ def _filter_vcf_filters(ht):
def get_x_chrom_filter(ht, x_interval):
return x_interval.contains(ht.locus)

def _filter_annotated_table(self, frequencies=None, **kwargs):
def _filter_variant_ids(self, ht, variant_ids):
if len(variant_ids) == 1:
variant_id_q = ht.alleles == [variant_ids[0][2], variant_ids[0][3]]
else:
variant_id_qs = [
(ht.locus == hl.locus(chrom, pos, reference_genome=self._genome_version)) &
(ht.alleles == [ref, alt])
for chrom, pos, ref, alt in variant_ids
]
variant_id_q = variant_id_qs[0]
for q in variant_id_qs[1:]:
variant_id_q |= q
return ht.filter(variant_id_q)

def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, **kwargs):
if gene_ids:
self._filter_by_gene_ids(gene_ids)

if rs_ids:
self._filter_rs_ids(rs_ids)

self._filter_by_frequency(frequencies)

def _filter_by_gene_ids(self, gene_ids):
gene_ids = hl.set(gene_ids)
self._ht = self._ht.filter(self._ht.sorted_transcript_consequences.any(lambda t: gene_ids.contains(t.gene_id)))

def _filter_rs_ids(self, rs_ids):
rs_id_set = hl.set(rs_ids)
self._ht = self._ht.filter(rs_id_set.contains(self._ht.rsid))

@staticmethod
def _formatted_chr_interval(interval):
return f'[chr{interval.replace("[", "")}' if interval.startswith('[') else f'chr{interval}'

def _parse_intervals(self, intervals, variant_ids):
if not (intervals or variant_ids):
return intervals, variant_ids

reference_genome = hl.get_reference(self._genome_version)
should_add_chr_prefix = any(c.startswith('chr') for c in reference_genome.contigs)

raw_intervals = intervals
if variant_ids:
if should_add_chr_prefix:
variant_ids = [(f'chr{chr}', *v_id) for chr, *v_id in variant_ids]
intervals = [f'[{chrom}:{pos}-{pos}]' for chrom, pos, _, _ in variant_ids]
elif should_add_chr_prefix:
intervals = [
f'[chr{interval.replace("[", "")}' if interval.startswith('[') else f'chr{interval}'
for interval in intervals
]

parsed_intervals = [
hl.eval(hl.parse_locus_interval(interval, reference_genome=self._genome_version, invalid_missing=True))
for interval in intervals
]
invalid_intervals = [raw_intervals[i] for i, interval in enumerate(parsed_intervals) if interval is None]
if invalid_intervals:
raise HTTPBadRequest(text=f'Invalid intervals: {", ".join(invalid_intervals)}')

return parsed_intervals, variant_ids

def _filter_by_frequency(self, frequencies):
frequencies = {k: v for k, v in (frequencies or {}).items() if k in self.POPULATIONS}
if not frequencies:
Expand Down Expand Up @@ -410,7 +483,7 @@ def _filter_by_frequency(self, frequencies):
self._ht = self._ht.filter(hl.is_missing(pop_expr) | pop_filter)

def _format_results(self, ht):
annotations = {k: v(ht) for k, v in self.annotation_fields.items()}
annotations = {k: v(ht) for k, v in self.annotation_fields().items()}
annotations.update({
'_sort': self._sort_order(ht),
'genomeVersion': self._genome_version.replace('GRCh', ''),
Expand Down
40 changes: 38 additions & 2 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from copy import deepcopy

from hail_search.test_utils import get_hail_search_body, FAMILY_2_VARIANT_SAMPLE_DATA, FAMILY_2_MISSING_SAMPLE_DATA, \
VARIANT1, VARIANT2, VARIANT3, VARIANT4, MULTI_PROJECT_SAMPLE_DATA, MULTI_PROJECT_MISSING_SAMPLE_DATA
VARIANT1, VARIANT2, VARIANT3, VARIANT4, MULTI_PROJECT_SAMPLE_DATA, MULTI_PROJECT_MISSING_SAMPLE_DATA, \
LOCATION_SEARCH, EXCLUDE_LOCATION_SEARCH, VARIANT_ID_SEARCH, RSID_SEARCH
from hail_search.web_app import init_web_app

PROJECT_2_VARIANT = {
Expand Down Expand Up @@ -182,6 +183,33 @@ async def test_quality_filter(self):
[VARIANT2, FAMILY_3_VARIANT], quality_filter={'min_gq': 40, 'min_ab': 50}, omit_sample_type='SV_WES',
)

async def test_location_search(self):
await self._assert_expected_search(
[VARIANT2, MULTI_FAMILY_VARIANT, VARIANT4], omit_sample_type='SV_WES', **LOCATION_SEARCH,
)

await self._assert_expected_search(
[VARIANT1], omit_sample_type='SV_WES', **EXCLUDE_LOCATION_SEARCH,
)

await self._assert_expected_search(
[MULTI_FAMILY_VARIANT], omit_sample_type='SV_WES',
intervals=LOCATION_SEARCH['intervals'][-1:], gene_ids=LOCATION_SEARCH['gene_ids'][:1]
)

async def test_variant_id_search(self):
await self._assert_expected_search([VARIANT2], omit_sample_type='SV_WES', **RSID_SEARCH)

await self._assert_expected_search([VARIANT1], omit_sample_type='SV_WES', **VARIANT_ID_SEARCH)

await self._assert_expected_search(
[VARIANT1], omit_sample_type='SV_WES', variant_ids=VARIANT_ID_SEARCH['variant_ids'][:1],
)

await self._assert_expected_search(
[], omit_sample_type='SV_WES', variant_ids=VARIANT_ID_SEARCH['variant_ids'][1:],
)

async def test_frequency_filter(self):
await self._assert_expected_search(
[VARIANT1, VARIANT4], frequencies={'seqr': {'af': 0.2}}, omit_sample_type='SV_WES',
Expand Down Expand Up @@ -217,7 +245,7 @@ async def test_frequency_filter(self):
omit_sample_type='SV_WES',
)

async def test_search_missing_data(self):
async def test_search_errors(self):
search_body = get_hail_search_body(sample_data=FAMILY_2_MISSING_SAMPLE_DATA)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
Expand All @@ -229,3 +257,11 @@ async def test_search_missing_data(self):
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')

search_body = get_hail_search_body(
intervals=LOCATION_SEARCH['intervals'] + ['1:1-99999999999'], omit_sample_type='SV_WES',
)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'Invalid intervals: 1:1-99999999999')
8 changes: 8 additions & 0 deletions hail_search/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@
HAIL_BACKEND_VARIANTS = [VARIANT2, MULTI_FAMILY_VARIANT]
HAIL_BACKEND_SINGLE_FAMILY_VARIANTS = [VARIANT2, VARIANT3]

LOCATION_SEARCH = {
'gene_ids': ['ENSG00000177000', 'ENSG00000097046'],
'intervals': ['2:1234-5678', '7:1-11100', '1:11785723-11806455', '1:91500851-91525764'],
}
EXCLUDE_LOCATION_SEARCH = {'intervals': LOCATION_SEARCH['intervals'], 'exclude_intervals': True}
VARIANT_ID_SEARCH = {'variant_ids': [['1', 10439, 'AC', 'A'], ['1', 91511686, 'TCA', 'G']], 'rs_ids': []}
RSID_SEARCH = {'variant_ids': [], 'rs_ids': ['rs1801131']}


def get_hail_search_body(genome_version='GRCh38', num_results=100, sample_data=None, omit_sample_type=None, **search_body):
sample_data = sample_data or EXPECTED_SAMPLE_DATA
Expand Down
2 changes: 1 addition & 1 deletion reference_data/management/tests/update_gencode_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_update_gencode_command(self, mock_logger, mock_update_transcripts_logge
])
calls = [
mock.call('Dropping the 3 existing TranscriptInfo entries'),
mock.call('Dropping the 50 existing GeneInfo entries'),
mock.call('Dropping the 52 existing GeneInfo entries'),
mock.call('Creating 2 GeneInfo records'),
mock.call('Done'),
mock.call('Stats: '),
Expand Down
38 changes: 38 additions & 0 deletions seqr/fixtures/reference_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,44 @@
"gencode_gene_type": "antisense_RNA",
"gencode_release": 27
}
}, {
"model": "reference_data.geneinfo",
"pk": 60,
"fields": {
"gene_id": "ENSG00000177000",
"gene_symbol": "MTHFR",
"chrom_grch37": "1",
"start_grch37": 11785723,
"end_grch37": 11806455,
"strand_grch37": "+",
"coding_region_size_grch37": 0,
"chrom_grch38": "1",
"start_grch38": 11785723,
"end_grch38": 11806455,
"strand_grch38": "+",
"coding_region_size_grch38": 0,
"gencode_gene_type": "protein_coding",
"gencode_release": 27
}
}, {
"model": "reference_data.geneinfo",
"pk": 61,
"fields": {
"gene_id": "ENSG00000097046",
"gene_symbol": "CDC7",
"chrom_grch37": "1",
"start_grch37": 91500851,
"end_grch37": 91525764,
"strand_grch37": "+",
"coding_region_size_grch37": 0,
"chrom_grch38": "1",
"start_grch38": 91500851,
"end_grch38": 91525764,
"strand_grch38": "+",
"coding_region_size_grch38": 0,
"gencode_gene_type": "protein_coding",
"gencode_release": 27
}
}, {
"model": "reference_data.transcriptinfo",
"pk": 1,
Expand Down
26 changes: 10 additions & 16 deletions seqr/utils/search/hail_search_utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
get_variants_for_variant_ids, InvalidSearchException
from seqr.utils.search.search_utils_tests import SearchTestHelper, MOCK_COUNTS
from hail_search.test_utils import get_hail_search_body, EXPECTED_SAMPLE_DATA, FAMILY_1_SAMPLE_DATA, \
FAMILY_2_ALL_SAMPLE_DATA, ALL_AFFECTED_SAMPLE_DATA, CUSTOM_AFFECTED_SAMPLE_DATA, HAIL_BACKEND_VARIANTS

FAMILY_2_ALL_SAMPLE_DATA, ALL_AFFECTED_SAMPLE_DATA, CUSTOM_AFFECTED_SAMPLE_DATA, HAIL_BACKEND_VARIANTS, \
LOCATION_SEARCH, EXCLUDE_LOCATION_SEARCH, VARIANT_ID_SEARCH, RSID_SEARCH
MOCK_HOST = 'http://test-hail-host'


Expand Down Expand Up @@ -72,33 +72,27 @@ def test_query_variants(self):
self.assertListEqual(variants, HAIL_BACKEND_VARIANTS[1:])
self._test_expected_search_call(sort='cadd', num_results=2)

self.search_model.search['locus'] = {'rawVariantItems': '1-248367227-TC-T,2-103343353-GAGA-G'}
self.search_model.search['locus'] = {'rawVariantItems': '1-10439-AC-A,1-91511686-TCA-G'}
query_variants(self.results_model, user=self.user, sort='in_omim')
self._test_expected_search_call(
num_results=2, dataset_type='VARIANTS', omit_sample_type='SV_WES', rs_ids=[],
variant_ids=[['1', 248367227, 'TC', 'T'], ['2', 103343353, 'GAGA', 'G']],
num_results=2, dataset_type='VARIANTS', omit_sample_type='SV_WES',
sort='in_omim', sort_metadata=['ENSG00000223972', 'ENSG00000243485', 'ENSG00000268020'],
**VARIANT_ID_SEARCH,
)

self.search_model.search['locus']['rawVariantItems'] = 'rs9876'
self.search_model.search['locus']['rawVariantItems'] = 'rs1801131'
query_variants(self.results_model, user=self.user, sort='constraint')
self._test_expected_search_call(
rs_ids=['rs9876'], variant_ids=[], sort='constraint', sort_metadata={'ENSG00000223972': 2},
sort='constraint', sort_metadata={'ENSG00000223972': 2}, **RSID_SEARCH,
)

self.search_model.search['locus']['rawItems'] = 'DDX11L1, chr2:1234-5678, chr7:100-10100%10, ENSG00000186092'
self.search_model.search['locus']['rawItems'] = 'CDC7, chr2:1234-5678, chr7:100-10100%10, ENSG00000177000'
query_variants(self.results_model, user=self.user)
self._test_expected_search_call(
gene_ids=['ENSG00000223972', 'ENSG00000186092'], intervals=[
'2:1234-5678', '7:1-11100', '1:11869-14409', '1:65419-71585'
],
)
self._test_expected_search_call(**LOCATION_SEARCH)

self.search_model.search['locus']['excludeLocations'] = True
query_variants(self.results_model, user=self.user)
self._test_expected_search_call(
intervals=['2:1234-5678', '7:1-11100', '1:11869-14409', '1:65419-71585'], exclude_intervals=True,
)
self._test_expected_search_call(**EXCLUDE_LOCATION_SEARCH)

self.search_model.search = {
'inheritance': {'mode': 'recessive', 'filter': {'affected': {
Expand Down
6 changes: 3 additions & 3 deletions seqr/views/apis/awesomebar_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def test_awesomebar_autocomplete_handler(self):
self.assertEqual(len(genes), 5)
self.assertListEqual(
[g['title'] for g in genes],
['ENSG00000135953', 'ENSG00000186092', 'ENSG00000185097', 'DDX11L1', 'ENSG00000237613'],
['ENSG00000135953', 'ENSG00000177000', 'ENSG00000186092', 'ENSG00000185097', 'DDX11L1'],
)
self.assertDictEqual(genes[1], {
self.assertDictEqual(genes[2], {
'key': 'ENSG00000186092',
'title': 'ENSG00000186092',
'description': '(OR4F5)',
'href': '/summary_data/gene_info/ENSG00000186092',
})
self.assertDictEqual(genes[3], {
self.assertDictEqual(genes[4], {
'key': 'ENSG00000223972',
'title': 'DDX11L1',
'description': '(ENSG00000223972)',
Expand Down
9 changes: 7 additions & 2 deletions seqr/views/apis/variant_search_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@
'SV0000001_2103343353_r0390_100': expected_detail_saved_variant,
'SV0000002_1248367227_r0390_100': EXPECTED_SAVED_VARIANT,
},
'genesById': {'ENSG00000227232': expected_pa_gene, 'ENSG00000268903': EXPECTED_GENE, 'ENSG00000233653': EXPECTED_GENE},
'genesById': {
'ENSG00000227232': expected_pa_gene, 'ENSG00000268903': EXPECTED_GENE, 'ENSG00000233653': EXPECTED_GENE,
'ENSG00000177000': mock.ANY, 'ENSG00000097046': mock.ANY,
},
'transcriptsById': {'ENST00000624735': {'isManeSelect': False, 'refseqId': None, 'transcriptId': 'ENST00000624735'}},
'search': {
'search': SEARCH,
Expand Down Expand Up @@ -717,7 +720,9 @@ def test_query_single_variant(self, mock_get_variant):
expected_search_response['variantTagsByGuid'].pop('VT1726945_2103343353_r0390_100')
expected_search_response['variantTagsByGuid'].pop('VT1726970_2103343353_r0004_tes')
expected_search_response['variantNotesByGuid'] = {}
expected_search_response['genesById'].pop('ENSG00000233653')
expected_search_response['genesById'] = {
k: v for k, v in expected_search_response['genesById'].items() if k in {'ENSG00000227232', 'ENSG00000268903'}
}
expected_search_response['searchedVariants'] = [single_family_variant]
self.assertDictEqual(response_json, expected_search_response)
self._assert_expected_results_family_context(response_json, locus_list_detail=True)
Expand Down

0 comments on commit 6483883

Please sign in to comment.