diff --git a/hail_search/hail_search_query.py b/hail_search/hail_search_query.py index 75b2fa7182..def9bbf10e 100644 --- a/hail_search/hail_search_query.py +++ b/hail_search/hail_search_query.py @@ -104,6 +104,12 @@ def _format_transcript_args(self): 'format_value': lambda value: value.rename({k: _to_camel_case(k) for k in value.keys()}), } + def _get_enum_lookup(self, field, subfield): + enum_field = self._enums.get(field, {}).get(subfield) + if enum_field is None: + return None + return {v: i for i, v in enumerate(enum_field)} + @staticmethod def _enum_field(value, enum, ht_globals=None, annotate_value=None, format_value=None, drop_fields=None, **kwargs): annotations = {} @@ -404,7 +410,7 @@ def _filter_variant_ids(self, ht, variant_ids): variant_id_q |= q return ht.filter(variant_id_q) - def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, **kwargs): + def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, in_silico=None, **kwargs): if gene_ids: self._filter_by_gene_ids(gene_ids) @@ -413,6 +419,8 @@ def _filter_annotated_table(self, gene_ids=None, rs_ids=None, frequencies=None, self._filter_by_frequency(frequencies) + self._filter_by_in_silico(in_silico) + def _filter_by_gene_ids(self, gene_ids): gene_ids = hl.set(gene_ids) self._ht = self._ht.filter(self._ht.sorted_transcript_consequences.any(lambda t: gene_ids.contains(t.gene_id))) @@ -484,6 +492,34 @@ def _filter_by_frequency(self, frequencies): pop_filter &= pf self._ht = self._ht.filter(hl.is_missing(pop_expr) | pop_filter) + def _filter_by_in_silico(self, in_silico_filters): + in_silico_filters = in_silico_filters or {} + require_score = in_silico_filters.get('requireScore', False) + in_silico_filters = {k: v for k, v in in_silico_filters.items() if k in self.PREDICTION_FIELDS_CONFIG and v} + if not in_silico_filters: + return + + in_silico_qs = [] + missing_qs = [] + for in_silico, value in in_silico_filters.items(): + score_path = self.PREDICTION_FIELDS_CONFIG[in_silico] + enum_lookup = self._get_enum_lookup(*score_path) + if enum_lookup is not None: + ht_value = self._ht[score_path.source][f'{score_path.field}_id'] + score_filter = ht_value == enum_lookup[value] + else: + ht_value = self._ht[score_path.source][score_path.field] + score_filter = ht_value >= float(value) + + in_silico_qs.append(score_filter) + if not require_score: + missing_qs.append(hl.is_missing(ht_value)) + + if missing_qs: + in_silico_qs.append(hl.all(missing_qs)) + + self._ht = self._ht.filter(hl.any(in_silico_qs)) + def _format_results(self, ht): annotations = {k: v(ht) for k, v in self.annotation_fields().items()} annotations.update({ diff --git a/hail_search/test_search.py b/hail_search/test_search.py index beb2e0a8eb..5fcee973d9 100644 --- a/hail_search/test_search.py +++ b/hail_search/test_search.py @@ -245,6 +245,17 @@ async def test_frequency_filter(self): omit_sample_type='SV_WES', ) + async def test_in_silico_filter(self): + in_silico = {'eigen': '5.5', 'mut_taster': 'P'} + await self._assert_expected_search( + [VARIANT1, VARIANT2, VARIANT4], in_silico=in_silico, omit_sample_type='SV_WES', + ) + + in_silico['requireScore'] = True + await self._assert_expected_search( + [VARIANT2, VARIANT4], in_silico=in_silico, omit_sample_type='SV_WES', + ) + async def test_search_errors(self): search_body = get_hail_search_body(sample_data=FAMILY_2_MISSING_SAMPLE_DATA) async with self.client.request('POST', '/search', json=search_body) as resp: