Skip to content

Commit

Permalink
Merge branch 'hail-backend-comp-het' of https://github.com/broadinsti…
Browse files Browse the repository at this point in the history
…tute/seqr into hail-backend-sort
  • Loading branch information
hanars committed Aug 16, 2023
2 parents 55b6589 + e39dbe5 commit a0f6372
Show file tree
Hide file tree
Showing 21 changed files with 94 additions and 46 deletions.
3 changes: 3 additions & 0 deletions hail_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
ANNOTATION_OVERRIDE_FIELDS = [
SCREEN_KEY, SPLICE_AI_FIELD, NEW_SV_FIELD, STRUCTURAL_ANNOTATION_FIELD,
]
HAS_ALLOWED_ANNOTATION = 'has_allowed_annotation'
HAS_ALLOWED_SECONDARY_ANNOTATION = f'{HAS_ALLOWED_ANNOTATION}_secondary'

XPOS = 'xpos'

Expand Down Expand Up @@ -57,6 +59,7 @@
},
}

PREFILTER_FREQ_CUTOFF = 0.01
PATH_FREQ_OVERRIDE_CUTOFF = 0.05
CLINVAR_PATH_FILTER = 'pathogenic'
CLINVAR_LIKELY_PATH_FILTER = 'likely_pathogenic'
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
This folder comprises a Hail (www.hail.is) native Table or MatrixTable.
Written with version 0.2.109-b71b065e4bb6
Created at 2023/08/11 14:19:35
Empty file.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
110 changes: 74 additions & 36 deletions hail_search/hail_search_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
ANY_AFFECTED, X_LINKED_RECESSIVE, REF_REF, REF_ALT, COMP_HET_ALT, ALT_ALT, HAS_ALT, HAS_REF, \
ANNOTATION_OVERRIDE_FIELDS, SCREEN_KEY, SPLICE_AI_FIELD, CLINVAR_KEY, HGMD_KEY, CLINVAR_PATH_SIGNIFICANCES, \
CLINVAR_PATH_FILTER, CLINVAR_LIKELY_PATH_FILTER, CLINVAR_PATH_RANGES, HGMD_PATH_RANGES, PATH_FREQ_OVERRIDE_CUTOFF, \
COMPOUND_HET, RECESSIVE, GROUPED_VARIANTS_FIELD, PATHOGENICTY_SORT_KEY, PATHOGENICTY_HGMD_SORT_KEY, \
ABSENT_PATH_SORT_OFFSET
PREFILTER_FREQ_CUTOFF, COMPOUND_HET, RECESSIVE, GROUPED_VARIANTS_FIELD, HAS_ALLOWED_ANNOTATION, \
HAS_ALLOWED_SECONDARY_ANNOTATION, PATHOGENICTY_SORT_KEY, PATHOGENICTY_HGMD_SORT_KEY, ABSENT_PATH_SORT_OFFSET

DATASETS_DIR = os.environ.get('DATASETS_DIR', '/hail_datasets')

Expand Down Expand Up @@ -158,11 +158,18 @@ def __init__(self, data_type, sample_data, genome_version, sort=XPOS, sort_metad
self._comp_het_ht = None
self._enums = None
self._globals = None
self._is_recessive_search = inheritance_mode == RECESSIVE
self._has_comp_het_search = inheritance_mode in {RECESSIVE, COMPOUND_HET}
self._inheritance_mode = inheritance_mode

self._load_filtered_table(sample_data, inheritance_mode=inheritance_mode, **kwargs)

@property
def _is_recessive_search(self):
return self._inheritance_mode == RECESSIVE

@property
def _has_comp_het_search(self):
return self._inheritance_mode in {RECESSIVE, COMPOUND_HET}

def _load_filtered_table(self, sample_data, intervals=None, exclude_intervals=False, variant_ids=None, **kwargs):
parsed_intervals, variant_ids = self._parse_intervals(intervals, variant_ids)
excluded_intervals = None
Expand All @@ -177,8 +184,8 @@ def _load_filtered_table(self, sample_data, intervals=None, exclude_intervals=Fa
self._comp_het_ht = self._filter_compound_hets()
if self._is_recessive_search:
self._ht = self._ht.filter(self._ht.family_entries.any(hl.is_defined))
if 'has_allowed_annotation_secondary' in self._ht.row:
self._ht = self._ht.filter(self._ht.has_allowed_annotation).drop('has_allowed_annotation_secondary')
if HAS_ALLOWED_SECONDARY_ANNOTATION in self._ht.row:
self._ht = self._ht.filter(self._ht[HAS_ALLOWED_ANNOTATION]).drop(HAS_ALLOWED_SECONDARY_ANNOTATION)
else:
self._ht = None

Expand Down Expand Up @@ -208,10 +215,10 @@ def import_filtered_table(self, sample_data, intervals=None, **kwargs):
try:
filtered_project_hts.append(self._filter_entries_table(project_ht, project_sample_data, **kwargs))
except HTTPBadRequest as e:
exception_messages.add(e.text)
exception_messages.add(e.reason)

if exception_messages:
raise HTTPBadRequest(text='; '.join(exception_messages))
raise HTTPBadRequest(reason='; '.join(exception_messages))

families_ht, num_families = filtered_project_hts[0]
entry_type = families_ht.family_entries.dtype.element_type
Expand Down Expand Up @@ -251,9 +258,10 @@ def _filter_entries_table(self, ht, sample_data, inheritance_mode=None, inherita
excluded_intervals=None, variant_ids=None, **kwargs):
if excluded_intervals:
ht = hl.filter_intervals(ht, excluded_intervals, keep=False)

if variant_ids:
elif variant_ids:
ht = self._filter_variant_ids(ht, variant_ids)
elif not self._load_table_kwargs['_intervals']:
ht = self._prefilter_entries_table(ht, **kwargs)

ht, sample_id_family_index_map, num_families = self._add_entry_sample_families(ht, sample_data)

Expand Down Expand Up @@ -283,7 +291,7 @@ def _add_entry_sample_families(cls, ht, sample_data):
missing_samples = set(sample_individual_map.keys()) - set(sample_id_index_map.keys())
if missing_samples:
raise HTTPBadRequest(
text=f'The following samples are available in seqr but missing the loaded data: {", ".join(sorted(missing_samples))}'
reason=f'The following samples are available in seqr but missing the loaded data: {", ".join(sorted(missing_samples))}'
)

affected_id_map = {AFFECTED: AFFECTED_ID, UNAFFECTED: UNAFFECTED_ID}
Expand Down Expand Up @@ -445,6 +453,9 @@ def _filter_variant_ids(self, ht, variant_ids):
variant_id_q |= q
return ht.filter(variant_id_q)

def _prefilter_entries_table(self, ht, **kwargs):
return ht

def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, in_silico=None, pathogenicity=None,
annotations=None, annotations_secondary=None, **kwargs):
if gene_ids:
Expand Down Expand Up @@ -495,7 +506,7 @@ def _parse_intervals(self, intervals, variant_ids):
]
invalid_intervals = [raw_intervals[i] for i, interval in enumerate(parsed_intervals) if interval is None]
if invalid_intervals:
raise HTTPBadRequest(text=f'Invalid intervals: {", ".join(invalid_intervals)}')
raise HTTPBadRequest(reason=f'Invalid intervals: {", ".join(invalid_intervals)}')

return parsed_intervals, variant_ids

Expand Down Expand Up @@ -584,9 +595,9 @@ def _filter_by_annotations(self, pathogenicity, annotations, annotations_seconda
return

self._ht = self._ht.annotate(**annotation_exprs)
annotation_filter = self._ht.has_allowed_annotation
annotation_filter = self._ht[HAS_ALLOWED_ANNOTATION]
if has_secondary_annotations:
annotation_filter |= self._ht.has_allowed_annotation_secondary
annotation_filter |= self._ht[HAS_ALLOWED_SECONDARY_ANNOTATION]
self._ht = self._ht.filter(annotation_filter)

def _get_allowed_consequences_annotations(self, annotations, annotation_filters):
Expand All @@ -604,9 +615,9 @@ def _filter_compound_hets(self):
ch_ht = ch_ht.annotate(gene_ids=self._gene_ids_expr(ch_ht))
ch_ht = ch_ht.explode(ch_ht.gene_ids)
formatted_rows_expr = hl.agg.collect(ch_ht.row)
if 'has_allowed_annotation_secondary' in self._ht.row:
primary_variants = hl.agg.filter(ch_ht.has_allowed_annotation, formatted_rows_expr)
secondary_variants = hl.agg.filter(ch_ht.has_allowed_annotation_secondary, formatted_rows_expr)
if HAS_ALLOWED_SECONDARY_ANNOTATION in self._ht.row:
primary_variants = hl.agg.filter(ch_ht[HAS_ALLOWED_ANNOTATION], formatted_rows_expr)
secondary_variants = hl.agg.filter(ch_ht[HAS_ALLOWED_SECONDARY_ANNOTATION], formatted_rows_expr)
else:
primary_variants = formatted_rows_expr
secondary_variants = formatted_rows_expr
Expand Down Expand Up @@ -805,20 +816,20 @@ def _selected_main_transcript_expr(ht):
gene_transcripts = getattr(ht, 'gene_transcripts', None)

allowed_transcripts = getattr(ht, 'allowed_transcripts', None)
if hasattr(ht, 'has_allowed_annotation_secondary'):
if hasattr(ht, HAS_ALLOWED_SECONDARY_ANNOTATION):
allowed_transcripts = hl.if_else(
allowed_transcripts.any(hl.is_defined), allowed_transcripts, ht.allowed_transcripts_secondary,
) if allowed_transcripts is not None else ht.allowed_transcripts_secondary

main_transcript = ht.sorted_transcript_consequences.first()
if gene_transcripts is not None:
if gene_transcripts is not None and allowed_transcripts is not None:
allowed_transcript_ids = hl.set(allowed_transcripts.map(lambda t: t.transcript_id))
matched_transcript = hl.or_else(
gene_transcripts.find(lambda t: allowed_transcript_ids.contains(t.transcript_id)),
gene_transcripts.first(),
)
elif gene_transcripts is not None:
matched_transcript = gene_transcripts.first()
if allowed_transcripts is not None:
allowed_transcript_ids = hl.set(allowed_transcripts.map(lambda t: t.transcript_id))
matched_transcript = hl.or_else(
gene_transcripts.find(lambda t: allowed_transcript_ids.contains(t.transcript_id)),
matched_transcript,
)
elif allowed_transcripts is not None:
matched_transcript = allowed_transcripts.first()
else:
Expand All @@ -844,29 +855,56 @@ def _format_transcript_args(self):

def _get_family_passes_quality_filter(self, quality_filter, ht=None, pathogenicity=None, **kwargs):
passes_quality = super(VariantHailTableQuery, self)._get_family_passes_quality_filter(quality_filter)
clinvar_path_ht = False if passes_quality is None else self._get_clinvar_filter_ht(pathogenicity)
clinvar_path_ht = False if passes_quality is None else self._get_loaded_filter_ht(
CLINVAR_KEY, 'clinvar_path_variants.ht', self._get_clinvar_prefilter, pathogenicity=pathogenicity)
if not clinvar_path_ht:
return passes_quality

return lambda entries: hl.is_defined(clinvar_path_ht[ht.key]) | passes_quality(entries)

def _get_clinvar_filter_ht(self, pathogenicity):
if self._filter_hts.get(CLINVAR_KEY) is not None:
return self._filter_hts[CLINVAR_KEY]
def _get_loaded_filter_ht(self, key, table_path, get_filters, **kwargs):
if self._filter_hts.get(key) is None:
ht_filter = get_filters(**kwargs)
if ht_filter is False:
self._filter_hts[key] = False
else:
ht = self._read_table(table_path)
if ht_filter is not True:
ht = ht.filter(ht[ht_filter])
self._filter_hts[key] = ht

return self._filter_hts[key]

def _get_clinvar_prefilter(self, pathogenicity=None):
clinvar_path_filters = self._get_clinvar_path_filters(pathogenicity)
if not clinvar_path_filters:
self._filter_hts[CLINVAR_KEY] = False
return False

clinvar_path_ht = self._read_table('clinvar_path_variants.ht')
if CLINVAR_LIKELY_PATH_FILTER not in clinvar_path_filters:
clinvar_path_ht = clinvar_path_ht.filter(clinvar_path_ht.is_pathogenic)
return 'is_pathogenic'
elif CLINVAR_PATH_FILTER not in clinvar_path_filters:
clinvar_path_ht = clinvar_path_ht.filter(clinvar_path_ht.is_likely_pathogenic)
self._filter_hts[CLINVAR_KEY] = clinvar_path_ht
return 'is_likely_pathogenic'
return True

def _prefilter_entries_table(self, ht, **kwargs):
af_ht = self._get_loaded_filter_ht(
GNOMAD_GENOMES_FIELD, 'high_af_variants.ht', self._get_gnomad_af_prefilter, **kwargs)
if af_ht:
ht = ht.filter(hl.is_missing(af_ht[ht.key]))
return ht

def _get_gnomad_af_prefilter(self, frequencies=None, pathogenicity=None, **kwargs):
gnomad_genomes_filter = (frequencies or {}).get(GNOMAD_GENOMES_FIELD, {})
af_cutoff = gnomad_genomes_filter.get('af')
if af_cutoff is None and gnomad_genomes_filter.get('ac') is not None:
af_cutoff = PREFILTER_FREQ_CUTOFF
if af_cutoff is None:
return False

if self._get_clinvar_path_filters(pathogenicity):
af_cutoff = max(af_cutoff, PATH_FREQ_OVERRIDE_CUTOFF)

return clinvar_path_ht
return 'is_gt_10_percent' if af_cutoff > PREFILTER_FREQ_CUTOFF else True

def _get_gene_id_filter(self, gene_ids):
self._ht = self._ht.annotate(
Expand All @@ -892,7 +930,7 @@ def _get_allowed_consequences_annotations(self, annotations, annotation_filters)
annotation_filters = annotation_filters + [hl.is_defined(allowed_transcripts.first())]

if annotation_filters:
annotation_exprs['has_allowed_annotation'] = hl.any(annotation_filters)
annotation_exprs[HAS_ALLOWED_ANNOTATION] = hl.any(annotation_filters)

return annotation_exprs

Expand Down
24 changes: 14 additions & 10 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,19 @@ async def test_frequency_filter(self):
)

await self._assert_expected_search(
[VARIANT1, VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.41}}, omit_sample_type='SV_WES',
[VARIANT1, VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.05}}, omit_sample_type='SV_WES',
)

await self._assert_expected_search(
[VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.41, 'hh': 1}}, omit_sample_type='SV_WES',
[VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.05, 'hh': 1}}, omit_sample_type='SV_WES',
)

await self._assert_expected_search(
[VARIANT1, VARIANT4], frequencies={'seqr': {'af': 0.2}, 'gnomad_genomes': {'af': 0.41}},
[VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.005}}, omit_sample_type='SV_WES',
)

await self._assert_expected_search(
[VARIANT4], frequencies={'seqr': {'af': 0.2}, 'gnomad_genomes': {'ac': 50}},
omit_sample_type='SV_WES',
)

Expand All @@ -280,7 +284,7 @@ async def test_frequency_filter(self):
annotations = {'splice_ai': '0.0'} # Ensures no variants are filtered out by annotation/path filters
await self._assert_expected_search(
[VARIANT1, VARIANT2, VARIANT4], frequencies={'gnomad_genomes': {'af': 0.01}}, omit_sample_type='SV_WES',
annotations=annotations, pathogenicity={'clinvar': ['likely_pathogenic', 'vus_or_conflicting']},
annotations=annotations, pathogenicity={'clinvar': ['pathogenic', 'likely_pathogenic', 'vus_or_conflicting']},
)

await self._assert_expected_search(
Expand Down Expand Up @@ -387,22 +391,22 @@ async def test_search_errors(self):
search_body = get_hail_search_body(sample_data=FAMILY_2_MISSING_SAMPLE_DATA)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')
reason = resp.reason
self.assertEqual(reason, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')

search_body = get_hail_search_body(sample_data=MULTI_PROJECT_MISSING_SAMPLE_DATA)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')
reason = resp.reason
self.assertEqual(reason, 'The following samples are available in seqr but missing the loaded data: NA19675, NA19678')

search_body = get_hail_search_body(
intervals=LOCATION_SEARCH['intervals'] + ['1:1-99999999999'], omit_sample_type='SV_WES',
)
async with self.client.request('POST', '/search', json=search_body) as resp:
self.assertEqual(resp.status, 400)
text = await resp.text()
self.assertEqual(text, 'Invalid intervals: 1:1-99999999999')
reason = resp.reason
self.assertEqual(reason, 'Invalid intervals: 1:1-99999999999')

async def test_sort(self):
await self._assert_expected_search(
Expand Down

0 comments on commit a0f6372

Please sign in to comment.