diff --git a/hail_search/queries/base.py b/hail_search/queries/base.py index 972b33dd15..9c853cd4b2 100644 --- a/hail_search/queries/base.py +++ b/hail_search/queries/base.py @@ -214,6 +214,7 @@ def __init__(self, sample_data, sort=XPOS, sort_metadata=None, num_results=100, self._is_multi_data_type_comp_het = False self.max_unaffected_samples = None self._load_table_kwargs = {} + self.entry_samples_by_family_guid = {} if sample_data: self._load_filtered_table(sample_data, **kwargs) @@ -367,8 +368,7 @@ def _filter_entries_table(self, ht, sample_data, inheritance_filter=None, qualit return ht, ch_ht - @classmethod - def _add_entry_sample_families(cls, ht, sample_data): + def _add_entry_sample_families(self, ht, sample_data): ht_globals = hl.eval(ht.globals) missing_samples = set() @@ -381,12 +381,14 @@ def _add_entry_sample_families(cls, ht, sample_data): if missing_family_samples: missing_samples.update(missing_family_samples) else: + sample_index_data = [ + (ht_family_samples.index(s['sample_id']), self._sample_entry_data(s, family_guid, ht_globals)) + for s in samples + ] family_sample_index_data.append( - (ht_globals.family_guids.index(family_guid), [ - (ht_family_samples.index(s['sample_id']), cls._sample_entry_data(s, family_guid, ht_globals)) - for s in samples - ]) + (ht_globals.family_guids.index(family_guid), sample_index_data) ) + self.entry_samples_by_family_guid[family_guid] = [s['sampleId'] for _, s in sample_index_data] if missing_samples: raise HTTPBadRequest( diff --git a/hail_search/queries/multi_data_types.py b/hail_search/queries/multi_data_types.py index 2c69c25275..33fae6d12d 100644 --- a/hail_search/queries/multi_data_types.py +++ b/hail_search/queries/multi_data_types.py @@ -45,12 +45,13 @@ def _load_filtered_table(self, *args, **kwargs): for data_type in sv_data_types: self._current_sv_data_type = data_type sv_query = self._data_type_queries[data_type] - self.max_unaffected_samples = max(variant_query.max_unaffected_samples, sv_query.max_unaffected_samples) - merged_ht = self._filter_data_type_comp_hets(variant_ht, variant_families, sv_query) + self.max_unaffected_samples = min(variant_query.max_unaffected_samples, sv_query.max_unaffected_samples) + merged_ht = self._filter_data_type_comp_hets(variant_query, variant_families, sv_query) if merged_ht is not None: self._comp_het_hts[data_type] = merged_ht.key_by() - def _filter_data_type_comp_hets(self, variant_ht, variant_families, sv_query): + def _filter_data_type_comp_hets(self, variant_query, variant_families, sv_query): + variant_ht = variant_query.unfiltered_comp_het_ht sv_ht = sv_query.unfiltered_comp_het_ht sv_type_del_ids = sv_query.get_allowed_sv_type_ids([f'{getattr(sv_query, "SV_TYPE_PREFIX", "")}DEL']) self._sv_type_del_id = list(sv_type_del_ids)[0] if sv_type_del_ids else None @@ -63,6 +64,19 @@ def _filter_data_type_comp_hets(self, variant_ht, variant_families, sv_query): if variant_families != sv_families: variant_ht = self._family_filtered_ch_ht(variant_ht, overlapped_families, variant_families) sv_ht = self._family_filtered_ch_ht(sv_ht, overlapped_families, sv_families) + else: + overlapped_families = variant_families + + variant_samples_by_family = variant_query.entry_samples_by_family_guid + sv_samples_by_family = sv_query.entry_samples_by_family_guid + if any(f for f in overlapped_families if variant_samples_by_family[f] != sv_samples_by_family[f]): + sv_sample_indices = hl.array([[ + sv_samples_by_family[f].index(s) if s in sv_samples_by_family[f] else None + for s in variant_samples_by_family[f] + ] for f in overlapped_families]) + sv_ht = sv_ht.annotate(family_entries=hl.enumerate(sv_sample_indices).starmap( + lambda family_i, indices: indices.map(lambda sample_i: sv_ht.family_entries[family_i][sample_i]) + )) variant_ch_ht = variant_ht.group_by('gene_ids').aggregate(v1=hl.agg.collect(variant_ht.row)) sv_ch_ht = sv_ht.group_by('gene_ids').aggregate(v2=hl.agg.collect(sv_ht.row)) diff --git a/hail_search/test_search.py b/hail_search/test_search.py index c02e1c0935..7a4b5bdf7a 100644 --- a/hail_search/test_search.py +++ b/hail_search/test_search.py @@ -203,6 +203,10 @@ async def _assert_expected_search(self, results, gene_counts=None, **search_kwar self.assertSetEqual(set(resp_json.keys()), {'results', 'total'}) self.assertEqual(resp_json['total'], len(results)) for i, result in enumerate(resp_json['results']): + if result != results[i]: + diff_0 = {k for k, v in results[i][0].items() if v != result[0][k]} + diff_1 = {k for k, v in results[i][1].items() if v != result[1][k]} + import pdb; pdb.set_trace() self.assertEqual(result, results[i]) if gene_counts: @@ -905,6 +909,18 @@ async def test_secondary_annotations_filter(self): annotations=gcnv_annotations_2, annotations_secondary=selected_transcript_annotations, ) + # Search works with a different number of samples within the family + missing_gt_gcnv_variant = { + **GCNV_VARIANT4, 'genotypes': {k: v for k, v in GCNV_VARIANT4['genotypes'].items() if k != 'I000005_hg00732'} + } + await self._assert_expected_search( + [[MULTI_DATA_TYPE_COMP_HET_VARIANT2, missing_gt_gcnv_variant]], + inheritance_mode='compound_het', pathogenicity=pathogenicity, + annotations=gcnv_annotations_2, annotations_secondary=selected_transcript_annotations, + sample_data={**EXPECTED_SAMPLE_DATA, 'SV_WES': [EXPECTED_SAMPLE_DATA['SV_WES'][0], EXPECTED_SAMPLE_DATA['SV_WES'][2]]} + + ) + # Do not return pairs where annotations match in a non-paired gene await self._assert_expected_search( [GCNV_VARIANT3], inheritance_mode='recessive',