Skip to content

Commit

Permalink
Accurately compute comp hets when the variant and SV data have different samples within a family
Browse files Browse the repository at this point in the history
  • Loading branch information
hanars committed Feb 14, 2024
1 parent eec6bb8 commit 2ee7b3e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
14 changes: 8 additions & 6 deletions hail_search/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __init__(self, sample_data, sort=XPOS, sort_metadata=None, num_results=100,
self._is_multi_data_type_comp_het = False
self.max_unaffected_samples = None
self._load_table_kwargs = {}
self.entry_samples_by_family_guid = {}

if sample_data:
self._load_filtered_table(sample_data, **kwargs)
Expand Down Expand Up @@ -367,8 +368,7 @@ def _filter_entries_table(self, ht, sample_data, inheritance_filter=None, qualit

return ht, ch_ht

@classmethod
def _add_entry_sample_families(cls, ht, sample_data):
def _add_entry_sample_families(self, ht, sample_data):
ht_globals = hl.eval(ht.globals)

missing_samples = set()
Expand All @@ -381,12 +381,14 @@ def _add_entry_sample_families(cls, ht, sample_data):
if missing_family_samples:
missing_samples.update(missing_family_samples)
else:
sample_index_data = [
(ht_family_samples.index(s['sample_id']), self._sample_entry_data(s, family_guid, ht_globals))
for s in samples
]
family_sample_index_data.append(
(ht_globals.family_guids.index(family_guid), [
(ht_family_samples.index(s['sample_id']), cls._sample_entry_data(s, family_guid, ht_globals))
for s in samples
])
(ht_globals.family_guids.index(family_guid), sample_index_data)
)
self.entry_samples_by_family_guid[family_guid] = [s['sampleId'] for _, s in sample_index_data]

if missing_samples:
raise HTTPBadRequest(
Expand Down
20 changes: 17 additions & 3 deletions hail_search/queries/multi_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ def _load_filtered_table(self, *args, **kwargs):
for data_type in sv_data_types:
self._current_sv_data_type = data_type
sv_query = self._data_type_queries[data_type]
self.max_unaffected_samples = max(variant_query.max_unaffected_samples, sv_query.max_unaffected_samples)
merged_ht = self._filter_data_type_comp_hets(variant_ht, variant_families, sv_query)
self.max_unaffected_samples = min(variant_query.max_unaffected_samples, sv_query.max_unaffected_samples)
merged_ht = self._filter_data_type_comp_hets(variant_query, variant_families, sv_query)
if merged_ht is not None:
self._comp_het_hts[data_type] = merged_ht.key_by()

def _filter_data_type_comp_hets(self, variant_ht, variant_families, sv_query):
def _filter_data_type_comp_hets(self, variant_query, variant_families, sv_query):
variant_ht = variant_query.unfiltered_comp_het_ht
sv_ht = sv_query.unfiltered_comp_het_ht
sv_type_del_ids = sv_query.get_allowed_sv_type_ids([f'{getattr(sv_query, "SV_TYPE_PREFIX", "")}DEL'])
self._sv_type_del_id = list(sv_type_del_ids)[0] if sv_type_del_ids else None
Expand All @@ -63,6 +64,19 @@ def _filter_data_type_comp_hets(self, variant_ht, variant_families, sv_query):
if variant_families != sv_families:
variant_ht = self._family_filtered_ch_ht(variant_ht, overlapped_families, variant_families)
sv_ht = self._family_filtered_ch_ht(sv_ht, overlapped_families, sv_families)
else:
overlapped_families = variant_families

variant_samples_by_family = variant_query.entry_samples_by_family_guid
sv_samples_by_family = sv_query.entry_samples_by_family_guid
if any(f for f in overlapped_families if variant_samples_by_family[f] != sv_samples_by_family[f]):
sv_sample_indices = hl.array([[
sv_samples_by_family[f].index(s) if s in sv_samples_by_family[f] else None
for s in variant_samples_by_family[f]
] for f in overlapped_families])
sv_ht = sv_ht.annotate(family_entries=hl.enumerate(sv_sample_indices).starmap(
lambda family_i, indices: indices.map(lambda sample_i: sv_ht.family_entries[family_i][sample_i])
))

variant_ch_ht = variant_ht.group_by('gene_ids').aggregate(v1=hl.agg.collect(variant_ht.row))
sv_ch_ht = sv_ht.group_by('gene_ids').aggregate(v2=hl.agg.collect(sv_ht.row))
Expand Down
16 changes: 16 additions & 0 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ async def _assert_expected_search(self, results, gene_counts=None, **search_kwar
self.assertSetEqual(set(resp_json.keys()), {'results', 'total'})
self.assertEqual(resp_json['total'], len(results))
for i, result in enumerate(resp_json['results']):
if result != results[i]:
diff_0 = {k for k, v in results[i][0].items() if v != result[0][k]}
diff_1 = {k for k, v in results[i][1].items() if v != result[1][k]}
import pdb; pdb.set_trace()
self.assertEqual(result, results[i])

if gene_counts:
Expand Down Expand Up @@ -905,6 +909,18 @@ async def test_secondary_annotations_filter(self):
annotations=gcnv_annotations_2, annotations_secondary=selected_transcript_annotations,
)

# Search works with a different number of samples within the family
missing_gt_gcnv_variant = {
**GCNV_VARIANT4, 'genotypes': {k: v for k, v in GCNV_VARIANT4['genotypes'].items() if k != 'I000005_hg00732'}
}
await self._assert_expected_search(
[[MULTI_DATA_TYPE_COMP_HET_VARIANT2, missing_gt_gcnv_variant]],
inheritance_mode='compound_het', pathogenicity=pathogenicity,
annotations=gcnv_annotations_2, annotations_secondary=selected_transcript_annotations,
sample_data={**EXPECTED_SAMPLE_DATA, 'SV_WES': [EXPECTED_SAMPLE_DATA['SV_WES'][0], EXPECTED_SAMPLE_DATA['SV_WES'][2]]}

)

# Do not return pairs where annotations match in a non-paired gene
await self._assert_expected_search(
[GCNV_VARIANT3], inheritance_mode='recessive',
Expand Down

0 comments on commit 2ee7b3e

Please sign in to comment.