Skip to content

Commit

Permalink
Merge pull request #3550 from broadinstitute/hail-backend-sort
Browse files Browse the repository at this point in the history
Hail backend sort
  • Loading branch information
hanars authored Aug 17, 2023
2 parents 719eac6 + 89f9f50 commit 81f4ec0
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 11 deletions.
4 changes: 4 additions & 0 deletions hail_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@

XPOS = 'xpos'

PATHOGENICTY_SORT_KEY = 'pathogenicity'
PATHOGENICTY_HGMD_SORT_KEY = 'pathogenicity_hgmd'
ABSENT_PATH_SORT_OFFSET = 12.5

ALT_ALT = 'alt_alt'
REF_REF = 'ref_ref'
REF_ALT = 'ref_alt'
Expand Down
73 changes: 66 additions & 7 deletions hail_search/hail_search_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
ANNOTATION_OVERRIDE_FIELDS, SCREEN_KEY, SPLICE_AI_FIELD, CLINVAR_KEY, HGMD_KEY, CLINVAR_PATH_SIGNIFICANCES, \
CLINVAR_PATH_FILTER, CLINVAR_LIKELY_PATH_FILTER, CLINVAR_PATH_RANGES, HGMD_PATH_RANGES, PATH_FREQ_OVERRIDE_CUTOFF, \
PREFILTER_FREQ_CUTOFF, COMPOUND_HET, RECESSIVE, GROUPED_VARIANTS_FIELD, HAS_ALLOWED_ANNOTATION, \
HAS_ALLOWED_SECONDARY_ANNOTATION
HAS_ALLOWED_SECONDARY_ANNOTATION, PATHOGENICTY_SORT_KEY, PATHOGENICTY_HGMD_SORT_KEY, ABSENT_PATH_SORT_OFFSET

DATASETS_DIR = os.environ.get('DATASETS_DIR', '/hail_datasets')

Expand Down Expand Up @@ -76,6 +76,7 @@ def load_globals(cls):
def _format_population_config(cls, pop_config):
base_pop_config = {field.lower(): field for field in cls.POPULATION_KEYS}
base_pop_config.update(pop_config)
base_pop_config.pop('sort', None)
return base_pop_config

def annotation_fields(self):
Expand Down Expand Up @@ -159,9 +160,10 @@ def _enum_field(value, enum, ht_globals=None, annotate_value=None, format_value=

return value

def __init__(self, sample_data, genome_version, sort=XPOS, num_results=100, inheritance_mode=None, **kwargs):
def __init__(self, sample_data, genome_version, sort=XPOS, sort_metadata=None, num_results=100, inheritance_mode=None, **kwargs):
self._genome_version = genome_version
self._sort = sort
self._sort_metadata = sort_metadata
self._num_results = num_results
self._ht = None
self._comp_het_ht = None
Expand Down Expand Up @@ -626,7 +628,7 @@ def _filter_compound_hets(self):
ch_ht = ch_ht.filter(ch_ht.comp_het_family_entries.any(hl.is_defined))

# Get possible pairs of variants within the same gene
ch_ht = ch_ht.annotate(gene_ids=hl.set(ch_ht.sorted_transcript_consequences.map(lambda t: t.gene_id)))
ch_ht = ch_ht.annotate(gene_ids=self._gene_ids_expr(ch_ht))
ch_ht = ch_ht.explode(ch_ht.gene_ids)
formatted_rows_expr = hl.agg.collect(ch_ht.row)
if HAS_ALLOWED_SECONDARY_ANNOTATION in self._ht.row:
Expand Down Expand Up @@ -660,6 +662,10 @@ def _filter_compound_hets(self):

return ch_ht

@staticmethod
def _gene_ids_expr(ht):
raise NotImplementedError

def _is_valid_comp_het_family(self, entries_1, entries_2):
return hl.is_defined(entries_1) & hl.is_defined(entries_2) & hl.enumerate(entries_1).all(lambda x: hl.any([
(x[1].affected_id != UNAFFECTED_ID),
Expand Down Expand Up @@ -711,7 +717,32 @@ def _sort_order(self, ht):
return sort_expressions

def _get_sort_expressions(self, ht, sort):
return self.SORTS[sort](ht)
if sort in self.SORTS:
return self.SORTS[sort](ht)

if sort in self.PREDICTION_FIELDS_CONFIG:
prediction_path = self.PREDICTION_FIELDS_CONFIG[sort]
return [-hl.float64(ht[prediction_path.source][prediction_path.field])]

if sort == 'in_omim':
return self._omim_sort(ht, hl.set(set(self._sort_metadata)))

if self._sort_metadata:
return self._gene_rank_sort(ht, hl.dict(self._sort_metadata))

sort_field = next((field for field, config in self.POPULATIONS.items() if config.get('sort') == sort), None)
if sort_field:
return [hl.float64(self.population_expression(ht, sort_field).af)]

return []

@classmethod
def _omim_sort(cls, r, omim_gene_set):
return []

@classmethod
def _gene_rank_sort(cls, r, gene_ranks):
return []


class VariantHailTableQuery(BaseHailTableQuery):
Expand All @@ -723,14 +754,14 @@ class VariantHailTableQuery(BaseHailTableQuery):
'AB': QualityFilterFormat(override=lambda gt: ~gt.GT.is_het(), scale=100),
}
POPULATIONS = {
'seqr': {'hom': 'hom', 'hemi': None, 'het': None},
'seqr': {'hom': 'hom', 'hemi': None, 'het': None, 'sort': 'callset_af'},
'topmed': {'hemi': None},
'exac': {
'filter_af': 'AF_POPMAX', 'ac': 'AC_Adj', 'an': 'AN_Adj', 'hom': 'AC_Hom', 'hemi': 'AC_Hemi',
'het': 'AC_Het',
},
'gnomad_exomes': {'filter_af': 'AF_POPMAX_OR_GLOBAL', 'het': None},
GNOMAD_GENOMES_FIELD: {'filter_af': 'AF_POPMAX_OR_GLOBAL', 'het': None},
'gnomad_exomes': {'filter_af': 'AF_POPMAX_OR_GLOBAL', 'het': None, 'sort': 'gnomad_exomes'},
GNOMAD_GENOMES_FIELD: {'filter_af': 'AF_POPMAX_OR_GLOBAL', 'het': None, 'sort': 'gnomad'},
}
POPULATION_FIELDS = {'seqr': 'gt_stats'}
PREDICTION_FIELDS_CONFIG = {
Expand Down Expand Up @@ -782,6 +813,16 @@ class VariantHailTableQuery(BaseHailTableQuery):
},
}

SORTS = {
'protein_consequence': lambda r: [
hl.min(r.sorted_transcript_consequences.flatmap(lambda t: t.consequence_term_ids)),
hl.min(r.selected_transcript.consequence_term_ids),
],
PATHOGENICTY_SORT_KEY: lambda r: [hl.or_else(r.clinvar.pathogenicity_id, ABSENT_PATH_SORT_OFFSET)],
}
SORTS[PATHOGENICTY_HGMD_SORT_KEY] = lambda r: VariantHailTableQuery.SORTS[PATHOGENICTY_SORT_KEY](r) + [r.hgmd.class_id]
SORTS.update(BaseHailTableQuery.SORTS)

@staticmethod
def _selected_main_transcript_expr(ht):
gene_id = getattr(ht, 'gene_id', None)
Expand Down Expand Up @@ -957,5 +998,23 @@ def _format_results(self, ht, annotation_fields):
ht = ht.annotate(selected_transcript=self._selected_main_transcript_expr(ht))
return super()._format_results(ht, annotation_fields)

@staticmethod
def _gene_ids_expr(ht):
return hl.set(ht.sorted_transcript_consequences.map(lambda t: t.gene_id))

@classmethod
def _omim_sort(cls, r, omim_gene_set):
return [
hl.if_else(omim_gene_set.contains(r.selected_transcript.gene_id), 0, 1),
-cls._gene_ids_expr(r).intersection(omim_gene_set).size(),
]

@classmethod
def _gene_rank_sort(cls, r, gene_ranks):
return [
gene_ranks.get(r.selected_transcript.gene_id),
hl.min(cls._gene_ids_expr(r).map(gene_ranks.get)),
]


QUERY_CLASS_MAP = {cls.DATA_TYPE: cls for cls in [VariantHailTableQuery]}
113 changes: 109 additions & 4 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@
'numAlt': 1, 'dp': 28, 'gq': 99, 'ab': 0.5,
}

# Ensures no variants are filtered out by annotation/path filters for compound hets
COMP_HET_ALL_PASS_FILTERS = {'annotations': {'splice_ai': '0.0'}, 'pathogenicity': {'clinvar': ['likely_pathogenic']}}


def _sorted(variant, sorts):
return {**variant, '_sort': sorts + variant['_sort']}


class HailSearchTestCase(AioHTTPTestCase):

Expand Down Expand Up @@ -155,16 +162,14 @@ async def test_inheritance_filter(self):
await self._assert_expected_search(
[VARIANT2, VARIANT3], inheritance_filter=gt_inheritance_filter, sample_data=FAMILY_2_VARIANT_SAMPLE_DATA)

# Ensures no variants are filtered out by annotation/path filters for compound hets
comp_het_filters = {'annotations': {'splice_ai': '0.0'}, 'pathogenicity': {'clinvar': ['likely_pathogenic']}}
await self._assert_expected_search(
[[VARIANT3, VARIANT4]], inheritance_mode='compound_het', sample_data=MULTI_PROJECT_SAMPLE_DATA,
**comp_het_filters,
**COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[PROJECT_2_VARIANT1, VARIANT2, [VARIANT3, VARIANT4]], inheritance_mode='recessive',
sample_data=MULTI_PROJECT_SAMPLE_DATA, **comp_het_filters,
sample_data=MULTI_PROJECT_SAMPLE_DATA, **COMP_HET_ALL_PASS_FILTERS,
)

async def test_quality_filter(self):
Expand Down Expand Up @@ -402,3 +407,103 @@ async def test_search_errors(self):
self.assertEqual(resp.status, 400)
reason = resp.reason
self.assertEqual(reason, 'Invalid intervals: 1:1-99999999999')

async def test_sort(self):
await self._assert_expected_search(
[_sorted(VARIANT2, [11, 11]), _sorted(VARIANT4, [11, 11]), _sorted(MULTI_FAMILY_VARIANT, [22, 24]),
_sorted(VARIANT1, [None, None])], omit_sample_type='SV_WES', sort='protein_consequence',
)

await self._assert_expected_search(
[_sorted(VARIANT4, [11, 11]), _sorted(SELECTED_ANNOTATION_TRANSCRIPT_VARIANT_2, [11, 22]),
_sorted(SELECTED_ANNOTATION_TRANSCRIPT_MULTI_FAMILY_VARIANT, [22, 22])],
omit_sample_type='SV_WES', sort='protein_consequence',
annotations={'other': ['non_coding_transcript_exon_variant'], 'splice_ai': '0'},
)

await self._assert_expected_search(
[_sorted(VARIANT1, [4]), _sorted(VARIANT2, [8]), _sorted(MULTI_FAMILY_VARIANT, [12.5]),
_sorted(VARIANT4, [12.5])], omit_sample_type='SV_WES', sort='pathogenicity',
)

await self._assert_expected_search(
[_sorted(VARIANT1, [4, None]), _sorted(VARIANT2, [8, 3]), _sorted(MULTI_FAMILY_VARIANT, [12.5, None]),
_sorted(VARIANT4, [12.5, None])], omit_sample_type='SV_WES', sort='pathogenicity_hgmd',
)

await self._assert_expected_search(
[_sorted(VARIANT2, [0]), _sorted(VARIANT4, [0.00026519427774474025]),
_sorted(VARIANT1, [0.034449315071105957]), _sorted(MULTI_FAMILY_VARIANT, [0.38041073083877563])],
omit_sample_type='SV_WES', sort='gnomad',
)

await self._assert_expected_search(
[_sorted(VARIANT1, [0]), _sorted(MULTI_FAMILY_VARIANT, [0]), _sorted(VARIANT4, [0]),
_sorted(VARIANT2, [0.28899794816970825])], omit_sample_type='SV_WES', sort='gnomad_exomes',
)

await self._assert_expected_search(
[_sorted(VARIANT4, [0.02222222276031971]), _sorted(VARIANT1, [0.10000000149011612]),
_sorted(VARIANT2, [0.31111112236976624]), _sorted(MULTI_FAMILY_VARIANT, [0.6666666865348816])],
omit_sample_type='SV_WES', sort='callset_af',
)

await self._assert_expected_search(
[_sorted(VARIANT4, [-29.899999618530273]), _sorted(VARIANT2, [-20.899999618530273]),
_sorted(VARIANT1, [-4.668000221252441]), _sorted(MULTI_FAMILY_VARIANT, [-2.753999948501587]), ],
omit_sample_type='SV_WES', sort='cadd',
)

await self._assert_expected_search(
[_sorted(VARIANT4, [-0.5260000228881836]), _sorted(VARIANT2, [-0.19699999690055847]),
_sorted(VARIANT1, [None]), _sorted(MULTI_FAMILY_VARIANT, [None])], omit_sample_type='SV_WES', sort='revel',
)

await self._assert_expected_search(
[_sorted(MULTI_FAMILY_VARIANT, [-0.009999999776482582]), _sorted(VARIANT2, [0]), _sorted(VARIANT4, [0]),
_sorted(VARIANT1, [None])], omit_sample_type='SV_WES', sort='splice_ai',
)

await self._assert_expected_search(
[_sorted(MULTI_FAMILY_VARIANT, [0, -2]), _sorted(VARIANT2, [0, -1]), _sorted(VARIANT4, [0, -1]), _sorted(VARIANT1, [1, 0])],
omit_sample_type='SV_WES', sort='in_omim', sort_metadata=['ENSG00000177000', 'ENSG00000097046'],
)

await self._assert_expected_search(
[_sorted(VARIANT2, [0, -1]), _sorted(MULTI_FAMILY_VARIANT, [1, -1]), _sorted(VARIANT1, [1, 0]), _sorted(VARIANT4, [1, 0])],
omit_sample_type='SV_WES', sort='in_omim', sort_metadata=['ENSG00000177000'],
)

await self._assert_expected_search(
[_sorted(VARIANT2, [2, 2]), _sorted(MULTI_FAMILY_VARIANT, [4, 2]), _sorted(VARIANT4, [4, 4]),
_sorted(VARIANT1, [None, None])], omit_sample_type='SV_WES', sort='constraint',
sort_metadata={'ENSG00000177000': 2, 'ENSG00000097046': 4},
)

await self._assert_expected_search(
[_sorted(VARIANT2, [3, 3]), _sorted(MULTI_FAMILY_VARIANT, [None, 3]), _sorted(VARIANT1, [None, None]),
_sorted(VARIANT4, [None, None])], omit_sample_type='SV_WES', sort='prioritized_gene',
sort_metadata={'ENSG00000177000': 3},
)

# size sort only applies to SVs, so has no impact on other variants
await self._assert_expected_search(
[VARIANT1, VARIANT2, MULTI_FAMILY_VARIANT, VARIANT4], sort='size', omit_sample_type='SV_WES',
)

# sort applies to compound hets
await self._assert_expected_search(
[_sorted(VARIANT2, [11, 11]), [_sorted(VARIANT4, [11, 11]), _sorted(VARIANT3, [22, 24])]],
sort='protein_consequence', inheritance_mode='recessive', omit_sample_type='SV_WES', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT4, [-0.5260000228881836]), _sorted(VARIANT3, [None])],
_sorted(VARIANT2, [-0.19699999690055847])],
sort='revel', inheritance_mode='recessive', omit_sample_type='SV_WES', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT3, [-0.009999999776482582]), _sorted(VARIANT4, [0])], _sorted(VARIANT2, [0])],
sort='splice_ai', inheritance_mode='recessive', omit_sample_type='SV_WES', **COMP_HET_ALL_PASS_FILTERS,
)

0 comments on commit 81f4ec0

Please sign in to comment.