Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hail backend filter location #3528

Merged
merged 12 commits into from
Aug 3, 2023
95 changes: 84 additions & 11 deletions hail_search/hail_search_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def _format_population_config(cls, pop_config):
base_pop_config.update(pop_config)
return base_pop_config

@property
def annotation_fields(self):
ht_globals = {k: hl.eval(self._ht[k]) for k in self.GLOBALS}
enums = ht_globals.pop('enums')
Expand Down Expand Up @@ -143,11 +142,19 @@ def __init__(self, data_type, sample_data, genome_version, sort=XPOS, num_result

self._load_filtered_table(data_type, sample_data, **kwargs)

def _load_filtered_table(self, data_type, sample_data, **kwargs):
self.import_filtered_table(data_type, sample_data, **kwargs)

def import_filtered_table(self, data_type, sample_data, **kwargs):
def _load_filtered_table(self, data_type, sample_data, intervals=None, exclude_intervals=False, variant_ids=None, **kwargs):
    """Parse location filters and delegate to import_filtered_table.

    Intervals are parsed once; when exclude_intervals is set they are passed
    as a blacklist (excluded_intervals) instead of a load-time include filter.
    """
    parsed_intervals, variant_ids = self._parse_intervals(intervals, variant_ids)
    # An exclusion search cannot use interval-filtered reads, so the parsed
    # intervals are routed to the excluded_intervals argument instead.
    excluded_intervals = parsed_intervals if exclude_intervals else None
    include_intervals = None if exclude_intervals else parsed_intervals
    self.import_filtered_table(
        data_type, sample_data,
        intervals=include_intervals,
        excluded_intervals=excluded_intervals,
        variant_ids=variant_ids, **kwargs)

def import_filtered_table(self, data_type, sample_data, intervals=None, **kwargs):
tables_path = f'{DATASETS_DIR}/{self._genome_version}/{data_type}'
load_table_kwargs = {'_intervals': intervals, '_filter_intervals': bool(intervals)}

family_samples = defaultdict(list)
project_samples = defaultdict(list)
Expand All @@ -158,13 +165,13 @@ def import_filtered_table(self, data_type, sample_data, **kwargs):
logger.info(f'Loading {data_type} data for {len(family_samples)} families in {len(project_samples)} projects')
if len(family_samples) == 1:
family_guid, family_sample_data = list(family_samples.items())[0]
family_ht = hl.read_table(f'{tables_path}/families/{family_guid}.ht')
family_ht = hl.read_table(f'{tables_path}/families/{family_guid}.ht', **load_table_kwargs)
families_ht = self._filter_entries_table(family_ht, family_sample_data, **kwargs)
else:
filtered_project_hts = []
exception_messages = set()
for project_guid, project_sample_data in project_samples.items():
project_ht = hl.read_table(f'{tables_path}/projects/{project_guid}.ht')
project_ht = hl.read_table(f'{tables_path}/projects/{project_guid}.ht', **load_table_kwargs)
try:
filtered_project_hts.append(self._filter_entries_table(project_ht, project_sample_data, **kwargs))
except HTTPBadRequest as e:
Expand Down Expand Up @@ -204,8 +211,14 @@ def import_filtered_table(self, data_type, sample_data, **kwargs):
)
self._filter_annotated_table(**kwargs)

def _filter_entries_table(self, ht, sample_data, inheritance_mode=None, inheritance_filter=None, quality_filter=None,
**kwargs):
def _filter_entries_table(self, ht, sample_data, inheritance_mode=None, inheritance_filter=None, quality_filter=None,
excluded_intervals=None, variant_ids=None, **kwargs):
if excluded_intervals:
ht = hl.filter_intervals(ht, excluded_intervals, keep=False)

if variant_ids:
ht = self._filter_variant_ids(ht, variant_ids)

ht, sample_id_family_index_map = self._add_entry_sample_families(ht, sample_data)

ht = self._filter_inheritance(
Expand Down Expand Up @@ -375,9 +388,69 @@ def _filter_vcf_filters(ht):
def get_x_chrom_filter(ht, x_interval):
return x_interval.contains(ht.locus)

def _filter_annotated_table(self, frequencies=None, **kwargs):
def _filter_variant_ids(self, ht, variant_ids):
    """Subset ht to rows matching the given (chrom, pos, ref, alt) variant ids.

    NOTE(review): the single-variant branch matches on alleles only —
    presumably the locus is already pinned by the interval used at load
    time; confirm against the caller before relying on this elsewhere.
    """
    if len(variant_ids) == 1:
        ref, alt = variant_ids[0][2], variant_ids[0][3]
        return ht.filter(ht.alleles == [ref, alt])

    # Multiple ids: OR together per-variant (locus AND alleles) predicates.
    per_variant_filters = [
        (ht.locus == hl.locus(chrom, pos, reference_genome=self._genome_version)) &
        (ht.alleles == [ref, alt])
        for chrom, pos, ref, alt in variant_ids
    ]
    combined_filter = per_variant_filters[0]
    for variant_filter in per_variant_filters[1:]:
        combined_filter |= variant_filter
    return ht.filter(combined_filter)

def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, **kwargs):
    """Apply annotation-level filters (genes, rsIDs, population frequency) to self._ht."""
    if gene_ids:
        self._filter_by_gene_ids(gene_ids)

    if rs_ids:
        self._filter_rs_ids(rs_ids)

    # Frequency filtering handles its own no-op case for empty/missing input.
    self._filter_by_frequency(frequencies)

def _filter_by_gene_ids(self, gene_ids):
    """Keep only rows with at least one transcript consequence in gene_ids."""
    gene_id_set = hl.set(gene_ids)
    in_requested_gene = lambda transcript: gene_id_set.contains(transcript.gene_id)
    self._ht = self._ht.filter(
        self._ht.sorted_transcript_consequences.any(in_requested_gene)
    )

def _filter_rs_ids(self, rs_ids):
    """Keep only rows whose rsid is one of the requested rs_ids."""
    requested = hl.set(rs_ids)
    self._ht = self._ht.filter(requested.contains(self._ht.rsid))

@staticmethod
def _formatted_chr_interval(interval):
return f'[chr{interval.replace("[", "")}' if interval.startswith('[') else f'chr{interval}'

def _parse_intervals(self, intervals, variant_ids):
    """Validate and parse location filters for the configured genome build.

    Args:
        intervals: optional list of interval strings ('1:100-200' or '[1:100-200]').
        variant_ids: optional list of (chrom, pos, ref, alt) variant tuples.

    Returns:
        Tuple of (parsed locus intervals or the original None, variant_ids with
        contigs chr-prefixed when the reference build requires it).

    Raises:
        HTTPBadRequest: if any interval fails to parse for this genome build.
    """
    if not (intervals or variant_ids):
        return intervals, variant_ids

    reference_genome = hl.get_reference(self._genome_version)
    should_add_chr_prefix = any(c.startswith('chr') for c in reference_genome.contigs)

    raw_intervals = intervals
    if variant_ids:
        if should_add_chr_prefix:
            # Avoid shadowing the builtin `chr` while prefixing contigs.
            variant_ids = [(f'chr{chrom}', *rest) for chrom, *rest in variant_ids]
        # Single-position intervals allow interval-filtered table reads.
        intervals = [f'[{chrom}:{pos}-{pos}]' for chrom, pos, _, _ in variant_ids]
    elif should_add_chr_prefix:
        # Reuse the shared helper instead of duplicating the prefix logic.
        intervals = [self._formatted_chr_interval(interval) for interval in intervals]

    parsed_intervals = [
        hl.eval(hl.parse_locus_interval(interval, reference_genome=self._genome_version, invalid_missing=True))
        for interval in intervals
    ]
    # Report errors using the user's original input when available. Intervals
    # derived from variant_ids have no raw form (raw_intervals is None), so
    # fall back to the built strings rather than crashing on a None index.
    error_source = raw_intervals if raw_intervals is not None else intervals
    invalid_intervals = [error_source[i] for i, interval in enumerate(parsed_intervals) if interval is None]
    if invalid_intervals:
        raise HTTPBadRequest(text=f'Invalid intervals: {", ".join(invalid_intervals)}')

    return parsed_intervals, variant_ids

def _filter_by_frequency(self, frequencies):
frequencies = {k: v for k, v in (frequencies or {}).items() if k in self.POPULATIONS}
if not frequencies:
Expand Down Expand Up @@ -410,7 +483,7 @@ def _filter_by_frequency(self, frequencies):
self._ht = self._ht.filter(hl.is_missing(pop_expr) | pop_filter)

def _format_results(self, ht):
annotations = {k: v(ht) for k, v in self.annotation_fields.items()}
annotations = {k: v(ht) for k, v in self.annotation_fields().items()}
annotations.update({
'_sort': self._sort_order(ht),
'genomeVersion': self._genome_version.replace('GRCh', ''),
Expand Down
40 changes: 38 additions & 2 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from copy import deepcopy

from hail_search.test_utils import get_hail_search_body, FAMILY_2_VARIANT_SAMPLE_DATA, FAMILY_2_MISSING_SAMPLE_DATA, \
VARIANT1, VARIANT2, VARIANT3, VARIANT4, MULTI_PROJECT_SAMPLE_DATA, MULTI_PROJECT_MISSING_SAMPLE_DATA
VARIANT1, VARIANT2, VARIANT3, VARIANT4, MULTI_PROJECT_SAMPLE_DATA, MULTI_PROJECT_MISSING_SAMPLE_DATA, \
LOCATION_SEARCH, EXCLUDE_LOCATION_SEARCH, VARIANT_ID_SEARCH, RSID_SEARCH
from hail_search.web_app import init_web_app

PROJECT_2_VARIANT = {
Expand Down Expand Up @@ -182,6 +183,33 @@ async def test_quality_filter(self):
[VARIANT2, FAMILY_3_VARIANT], quality_filter={'min_gq': 40, 'min_ab': 50}, omit_sample_type='SV_WES',
)

async def test_location_search(self):
    """Search by genes + intervals, by excluded intervals, and by a subset of each."""
    await self._assert_expected_search(
        [VARIANT2, MULTI_FAMILY_VARIANT, VARIANT4], omit_sample_type='SV_WES', **LOCATION_SEARCH,
    )

    # Excluding the same locations returns the complementary variant.
    await self._assert_expected_search(
        [VARIANT1], omit_sample_type='SV_WES', **EXCLUDE_LOCATION_SEARCH,
    )

    # Narrower search: only the last interval and the first gene.
    last_interval = LOCATION_SEARCH['intervals'][-1:]
    first_gene = LOCATION_SEARCH['gene_ids'][:1]
    await self._assert_expected_search(
        [MULTI_FAMILY_VARIANT], omit_sample_type='SV_WES',
        intervals=last_interval, gene_ids=first_gene,
    )

async def test_variant_id_search(self):
    """Search by rsID and by exact variant ids, including each id individually."""
    await self._assert_expected_search([VARIANT2], omit_sample_type='SV_WES', **RSID_SEARCH)

    await self._assert_expected_search([VARIANT1], omit_sample_type='SV_WES', **VARIANT_ID_SEARCH)

    variant_ids = VARIANT_ID_SEARCH['variant_ids']
    await self._assert_expected_search(
        [VARIANT1], omit_sample_type='SV_WES', variant_ids=variant_ids[:1],
    )

    # The second variant id matches no loaded variant.
    await self._assert_expected_search(
        [], omit_sample_type='SV_WES', variant_ids=variant_ids[1:],
    )

async def test_frequency_filter(self):
await self._assert_expected_search(
[VARIANT1, VARIANT4], frequencies={'seqr': {'af': 0.2}}, omit_sample_type='SV_WES',
Expand Down Expand Up @@ -217,7 +245,7 @@ async def test_frequency_filter(self):
omit_sample_type='SV_WES',
)

async def test_search_missing_data(self):
async def test_search_errors(self):
search_body = get_hail_search_body(sample_data=FAMILY_2_MISSING_SAMPLE_DATA)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
Expand All @@ -229,3 +257,11 @@ async def test_search_missing_data(self):
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')

search_body = get_hail_search_body(
intervals=LOCATION_SEARCH['intervals'] + ['1:1-99999999999'], omit_sample_type='SV_WES',
)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'Invalid intervals: 1:1-99999999999')
8 changes: 8 additions & 0 deletions hail_search/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@
HAIL_BACKEND_VARIANTS = [VARIANT2, MULTI_FAMILY_VARIANT]
HAIL_BACKEND_SINGLE_FAMILY_VARIANTS = [VARIANT2, VARIANT3]

# Shared location-filter fixtures. Gene ids correspond to MTHFR and CDC7
# (see the reference_data fixtures); the last two intervals cover those
# genes' coordinate ranges, the first two are unrelated regions.
LOCATION_SEARCH = {
    'gene_ids': ['ENSG00000177000', 'ENSG00000097046'],
    'intervals': ['2:1234-5678', '7:1-11100', '1:11785723-11806455', '1:91500851-91525764'],
}
# Same intervals applied as an exclusion (blacklist) filter.
EXCLUDE_LOCATION_SEARCH = {'intervals': LOCATION_SEARCH['intervals'], 'exclude_intervals': True}
# Exact variant lookups as (chrom, pos, ref, alt) lists, with no rsID filter.
VARIANT_ID_SEARCH = {'variant_ids': [['1', 10439, 'AC', 'A'], ['1', 91511686, 'TCA', 'G']], 'rs_ids': []}
# rsID-only lookup with no exact variant ids.
RSID_SEARCH = {'variant_ids': [], 'rs_ids': ['rs1801131']}


def get_hail_search_body(genome_version='GRCh38', num_results=100, sample_data=None, omit_sample_type=None, **search_body):
sample_data = sample_data or EXPECTED_SAMPLE_DATA
Expand Down
2 changes: 1 addition & 1 deletion reference_data/management/tests/update_gencode_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_update_gencode_command(self, mock_logger, mock_update_transcripts_logge
])
calls = [
mock.call('Dropping the 3 existing TranscriptInfo entries'),
mock.call('Dropping the 50 existing GeneInfo entries'),
mock.call('Dropping the 52 existing GeneInfo entries'),
mock.call('Creating 2 GeneInfo records'),
mock.call('Done'),
mock.call('Stats: '),
Expand Down
38 changes: 38 additions & 0 deletions seqr/fixtures/reference_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,44 @@
"gencode_gene_type": "antisense_RNA",
"gencode_release": 27
}
}, {
"model": "reference_data.geneinfo",
"pk": 60,
"fields": {
"gene_id": "ENSG00000177000",
"gene_symbol": "MTHFR",
"chrom_grch37": "1",
"start_grch37": 11785723,
"end_grch37": 11806455,
"strand_grch37": "+",
"coding_region_size_grch37": 0,
"chrom_grch38": "1",
"start_grch38": 11785723,
"end_grch38": 11806455,
"strand_grch38": "+",
"coding_region_size_grch38": 0,
"gencode_gene_type": "protein_coding",
"gencode_release": 27
}
}, {
"model": "reference_data.geneinfo",
"pk": 61,
"fields": {
"gene_id": "ENSG00000097046",
"gene_symbol": "CDC7",
"chrom_grch37": "1",
"start_grch37": 91500851,
"end_grch37": 91525764,
"strand_grch37": "+",
"coding_region_size_grch37": 0,
"chrom_grch38": "1",
"start_grch38": 91500851,
"end_grch38": 91525764,
"strand_grch38": "+",
"coding_region_size_grch38": 0,
"gencode_gene_type": "protein_coding",
"gencode_release": 27
}
}, {
"model": "reference_data.transcriptinfo",
"pk": 1,
Expand Down
26 changes: 10 additions & 16 deletions seqr/utils/search/hail_search_utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
get_variants_for_variant_ids, InvalidSearchException
from seqr.utils.search.search_utils_tests import SearchTestHelper, MOCK_COUNTS
from hail_search.test_utils import get_hail_search_body, EXPECTED_SAMPLE_DATA, FAMILY_1_SAMPLE_DATA, \
FAMILY_2_ALL_SAMPLE_DATA, ALL_AFFECTED_SAMPLE_DATA, CUSTOM_AFFECTED_SAMPLE_DATA, HAIL_BACKEND_VARIANTS

FAMILY_2_ALL_SAMPLE_DATA, ALL_AFFECTED_SAMPLE_DATA, CUSTOM_AFFECTED_SAMPLE_DATA, HAIL_BACKEND_VARIANTS, \
LOCATION_SEARCH, EXCLUDE_LOCATION_SEARCH, VARIANT_ID_SEARCH, RSID_SEARCH
MOCK_HOST = 'http://test-hail-host'


Expand Down Expand Up @@ -72,33 +72,27 @@ def test_query_variants(self):
self.assertListEqual(variants, HAIL_BACKEND_VARIANTS[1:])
self._test_expected_search_call(sort='cadd', num_results=2)

self.search_model.search['locus'] = {'rawVariantItems': '1-248367227-TC-T,2-103343353-GAGA-G'}
self.search_model.search['locus'] = {'rawVariantItems': '1-10439-AC-A,1-91511686-TCA-G'}
query_variants(self.results_model, user=self.user, sort='in_omim')
self._test_expected_search_call(
num_results=2, dataset_type='VARIANTS', omit_sample_type='SV_WES', rs_ids=[],
variant_ids=[['1', 248367227, 'TC', 'T'], ['2', 103343353, 'GAGA', 'G']],
num_results=2, dataset_type='VARIANTS', omit_sample_type='SV_WES',
sort='in_omim', sort_metadata=['ENSG00000223972', 'ENSG00000243485', 'ENSG00000268020'],
**VARIANT_ID_SEARCH,
)

self.search_model.search['locus']['rawVariantItems'] = 'rs9876'
self.search_model.search['locus']['rawVariantItems'] = 'rs1801131'
query_variants(self.results_model, user=self.user, sort='constraint')
self._test_expected_search_call(
rs_ids=['rs9876'], variant_ids=[], sort='constraint', sort_metadata={'ENSG00000223972': 2},
sort='constraint', sort_metadata={'ENSG00000223972': 2}, **RSID_SEARCH,
)

self.search_model.search['locus']['rawItems'] = 'DDX11L1, chr2:1234-5678, chr7:100-10100%10, ENSG00000186092'
self.search_model.search['locus']['rawItems'] = 'CDC7, chr2:1234-5678, chr7:100-10100%10, ENSG00000177000'
query_variants(self.results_model, user=self.user)
self._test_expected_search_call(
gene_ids=['ENSG00000223972', 'ENSG00000186092'], intervals=[
'2:1234-5678', '7:1-11100', '1:11869-14409', '1:65419-71585'
],
)
self._test_expected_search_call(**LOCATION_SEARCH)

self.search_model.search['locus']['excludeLocations'] = True
query_variants(self.results_model, user=self.user)
self._test_expected_search_call(
intervals=['2:1234-5678', '7:1-11100', '1:11869-14409', '1:65419-71585'], exclude_intervals=True,
)
self._test_expected_search_call(**EXCLUDE_LOCATION_SEARCH)

self.search_model.search = {
'inheritance': {'mode': 'recessive', 'filter': {'affected': {
Expand Down
6 changes: 3 additions & 3 deletions seqr/views/apis/awesomebar_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ def test_awesomebar_autocomplete_handler(self):
self.assertEqual(len(genes), 5)
self.assertListEqual(
[g['title'] for g in genes],
['ENSG00000135953', 'ENSG00000186092', 'ENSG00000185097', 'DDX11L1', 'ENSG00000237613'],
['ENSG00000135953', 'ENSG00000177000', 'ENSG00000186092', 'ENSG00000185097', 'DDX11L1'],
)
self.assertDictEqual(genes[1], {
self.assertDictEqual(genes[2], {
'key': 'ENSG00000186092',
'title': 'ENSG00000186092',
'description': '(OR4F5)',
'href': '/summary_data/gene_info/ENSG00000186092',
})
self.assertDictEqual(genes[3], {
self.assertDictEqual(genes[4], {
'key': 'ENSG00000223972',
'title': 'DDX11L1',
'description': '(ENSG00000223972)',
Expand Down
9 changes: 7 additions & 2 deletions seqr/views/apis/variant_search_api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@
'SV0000001_2103343353_r0390_100': expected_detail_saved_variant,
'SV0000002_1248367227_r0390_100': EXPECTED_SAVED_VARIANT,
},
'genesById': {'ENSG00000227232': expected_pa_gene, 'ENSG00000268903': EXPECTED_GENE, 'ENSG00000233653': EXPECTED_GENE},
'genesById': {
'ENSG00000227232': expected_pa_gene, 'ENSG00000268903': EXPECTED_GENE, 'ENSG00000233653': EXPECTED_GENE,
'ENSG00000177000': mock.ANY, 'ENSG00000097046': mock.ANY,
},
'transcriptsById': {'ENST00000624735': {'isManeSelect': False, 'refseqId': None, 'transcriptId': 'ENST00000624735'}},
'search': {
'search': SEARCH,
Expand Down Expand Up @@ -717,7 +720,9 @@ def test_query_single_variant(self, mock_get_variant):
expected_search_response['variantTagsByGuid'].pop('VT1726945_2103343353_r0390_100')
expected_search_response['variantTagsByGuid'].pop('VT1726970_2103343353_r0004_tes')
expected_search_response['variantNotesByGuid'] = {}
expected_search_response['genesById'].pop('ENSG00000233653')
expected_search_response['genesById'] = {
k: v for k, v in expected_search_response['genesById'].items() if k in {'ENSG00000227232', 'ENSG00000268903'}
}
expected_search_response['searchedVariants'] = [single_family_variant]
self.assertDictEqual(response_json, expected_search_response)
self._assert_expected_results_family_context(response_json, locus_list_detail=True)
Expand Down
Loading