Skip to content

Commit

Permalink
Merge pull request #3900 from broadinstitute/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
hanars authored Feb 16, 2024
2 parents 9b2815f + d1851c6 commit 760e8c8
Show file tree
Hide file tree
Showing 18 changed files with 110 additions and 32 deletions.
Binary file modified hail_search/fixtures/GRCh38/MITO/annotations.ht/.README.txt.crc
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions hail_search/fixtures/GRCh38/MITO/annotations.ht/README.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
This folder comprises a Hail (www.hail.is) native Table or MatrixTable.
Written with version 0.2.124-13536b531342
Created at 2023/11/22 10:50:28
Written with version 0.2.120-f00f916faf78
Created at 2024/02/15 17:49:47
Binary file modified hail_search/fixtures/GRCh38/MITO/annotations.ht/metadata.json.gz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
23 changes: 19 additions & 4 deletions hail_search/queries/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
DATASETS_DIR = os.environ.get('DATASETS_DIR', '/hail_datasets')
SSD_DATASETS_DIR = os.environ.get('SSD_DATASETS_DIR', DATASETS_DIR)

# Number of filtered genes at which pre-filtering a table by gene-intervals does not improve performance
# Estimated based on behavior for several representative gene lists
MAX_GENE_INTERVALS = 100

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -245,13 +249,19 @@ def _read_table(self, path, drop_globals=None, use_ssd_dir=False, skip_missing_f
table_path = self._get_table_path(path, use_ssd_dir=use_ssd_dir)
if 'variant_ht' in self._load_table_kwargs:
ht = self._query_table_annotations(self._load_table_kwargs['variant_ht'], table_path)
if skip_missing_field and not ht.any(hl.is_defined(ht[skip_missing_field])):
if self._should_skip_ht(ht, skip_missing_field):
return None
ht_globals = hl.read_table(table_path).globals
if drop_globals:
ht_globals = ht_globals.drop(*drop_globals)
return ht.annotate_globals(**hl.eval(ht_globals))
return hl.read_table(table_path, **self._load_table_kwargs)

ht = hl.read_table(table_path, **self._load_table_kwargs)
return None if self._should_skip_ht(ht, skip_missing_field) else ht

@staticmethod
def _should_skip_ht(ht, skip_missing_field):
return skip_missing_field and not ht.any(hl.is_defined(ht[skip_missing_field]))

@staticmethod
def _query_table_annotations(ht, query_table_path):
Expand Down Expand Up @@ -290,6 +300,8 @@ def _load_filtered_project_hts(self, project_samples, skip_all_missing=False, **
if exception_messages:
raise HTTPBadRequest(reason='; '.join(exception_messages))

if len(project_samples) > len(filtered_project_hts):
logger.info(f'Found {len(filtered_project_hts)} {self.DATA_TYPE} projects with matched entries')
return filtered_project_hts

def import_filtered_table(self, project_samples, num_families, intervals=None, **kwargs):
Expand Down Expand Up @@ -561,7 +573,7 @@ def _filter_rs_ids(self, ht, rs_ids):
rs_id_set = hl.set(rs_ids)
return ht.filter(rs_id_set.contains(ht.rsid))

def _parse_intervals(self, intervals, **kwargs):
def _parse_intervals(self, intervals, gene_ids=None, **kwargs):
parsed_variant_keys = self._parse_variant_keys(**kwargs)
if parsed_variant_keys:
self._load_table_kwargs['variant_ht'] = hl.Table.parallelize(parsed_variant_keys).key_by(*self.KEY_FIELD)
Expand All @@ -582,6 +594,9 @@ def _parse_intervals(self, intervals, **kwargs):
reference_genome = hl.get_reference(self.GENOME_VERSION)
intervals = (intervals or []) + [reference_genome.x_contigs[0]]

if len(intervals) > MAX_GENE_INTERVALS and len(intervals) == len(gene_ids or []):
return []

parsed_intervals = [
hl.eval(hl.parse_locus_interval(interval, reference_genome=self.GENOME_VERSION, invalid_missing=True))
for interval in intervals
Expand Down Expand Up @@ -976,7 +991,7 @@ def _get_sort_expressions(self, ht, sort):

if sort in self.PREDICTION_FIELDS_CONFIG:
prediction_path = self.PREDICTION_FIELDS_CONFIG[sort]
return [-hl.float64(ht[prediction_path.source][prediction_path.field])]
return [hl.or_else(-hl.float64(ht[prediction_path.source][prediction_path.field]), 0)]

if sort == OMIM_SORT:
return self._omim_sort(ht, hl.set(set(self._sort_metadata)))
Expand Down
4 changes: 3 additions & 1 deletion hail_search/queries/mito.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ class MitoHailTableQuery(BaseHailTableQuery):
CORE_FIELDS = BaseHailTableQuery.CORE_FIELDS + ['rsid']
MITO_ANNOTATION_FIELDS = {
'commonLowHeteroplasmy': lambda r: r.common_low_heteroplasmy,
'highConstraintRegion': lambda r: r.high_constraint_region,
'highConstraintRegion': (
lambda r: r.high_constraint_region if hasattr(r, 'high_constraint_region') else r.high_constraint_region_mito
),
'mitomapPathogenic': lambda r: r.mitomap.pathogenic,
}
BASE_ANNOTATION_FIELDS = {
Expand Down
25 changes: 14 additions & 11 deletions hail_search/queries/multi_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def format_search_ht(self):
dt_ht = query.format_search_ht()
if dt_ht is None:
continue
merged_sort_expr = self._merged_sort_expr(data_type, dt_ht)
if merged_sort_expr is not None:
dt_ht = dt_ht.annotate(_sort=merged_sort_expr)
dt_ht = self._merged_sort(data_type, dt_ht)
hts.append(dt_ht.select('_sort', **{data_type: dt_ht.row}))

for data_type, ch_ht in self._comp_het_hts.items():
Expand All @@ -126,7 +124,7 @@ def format_search_ht(self):
v2=self._format_comp_het_result(ch_ht.v2, data_type),
)
hts.append(ch_ht.select(
_sort=hl.sorted([ch_ht.v1._sort, ch_ht.v2._sort])[0],
_sort=hl.sorted([ch_ht.v1._sort.map(hl.float64), ch_ht.v2._sort.map(hl.float64)])[0],
**{f'comp_het_{data_type}': ch_ht.row},
))

Expand All @@ -137,21 +135,26 @@ def format_search_ht(self):
return ht

def _format_comp_het_result(self, v, data_type):
return self._data_type_queries[data_type]._format_results(v)
result = self._data_type_queries[data_type]._format_results(v)
return self._merged_sort(data_type, result)

def _merged_sort_expr(self, data_type, ht):
def _merged_sort(self, data_type, ht):
# Certain sorts have an extra element for variant-type data, so need to add an element for SV data
if not data_type.startswith('SV'):
return None
return ht

sort_expr = None
if self._sort == CONSEQUENCE_SORT:
return hl.array([hl.float64(4.5)]).extend(ht._sort.map(hl.float64))
sort_expr = hl.array([hl.float64(4.5)]).extend(ht._sort.map(hl.float64))
elif self._sort == OMIM_SORT:
return hl.array([hl.int64(0)]).extend(ht._sort)
sort_expr = hl.array([hl.int64(0)]).extend(ht._sort)
elif self._sort_metadata:
return ht._sort[:1].extend(ht._sort)
sort_expr = ht._sort[:1].extend(ht._sort)

if sort_expr is not None:
ht = ht.annotate(_sort=sort_expr)

return None
return ht

def _format_collected_rows(self, collected):
data_types = [*self._data_type_queries, *[f'comp_het_{data_type}' for data_type in self._comp_het_hts]]
Expand Down
86 changes: 72 additions & 14 deletions hail_search/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@
'ENSG00000277972': {'total': 1, 'families': {'F000002_2': 1}},
}

OMIM_SORT_METADATA = ['ENSG00000177000', 'ENSG00000097046', 'ENSG00000275023']


def _sorted(variant, sorts):
return {**variant, '_sort': sorts + variant['_sort']}
Expand Down Expand Up @@ -1092,30 +1094,29 @@ async def test_sort(self):

await self._assert_expected_search(
[_sorted(VARIANT4, [-0.5260000228881836]), _sorted(VARIANT2, [-0.19699999690055847]),
_sorted(VARIANT1, [None]), _sorted(MULTI_FAMILY_VARIANT, [None])], omit_sample_type='SV_WES', sort='revel',
_sorted(VARIANT1, [0]), _sorted(MULTI_FAMILY_VARIANT, [0])], omit_sample_type='SV_WES', sort='revel',
)

await self._assert_expected_search(
[_sorted(MULTI_FAMILY_VARIANT, [-0.009999999776482582]), _sorted(VARIANT2, [0]), _sorted(VARIANT4, [0]),
_sorted(VARIANT1, [None])], omit_sample_type='SV_WES', sort='splice_ai',
_sorted(VARIANT1, [0])], omit_sample_type='SV_WES', sort='splice_ai',
)

omim_sort_metadata = ['ENSG00000177000', 'ENSG00000097046', 'ENSG00000275023']
sort = 'in_omim'
await self._assert_expected_search(
[_sorted(MULTI_FAMILY_VARIANT, [0, -2]), _sorted(VARIANT2, [0, -1]), _sorted(VARIANT4, [0, -1]), _sorted(VARIANT1, [1, 0])],
omit_sample_type='SV_WES', sort=sort, sort_metadata=omim_sort_metadata,
omit_sample_type='SV_WES', sort=sort, sort_metadata=OMIM_SORT_METADATA,
)

await self._assert_expected_search(
[_sorted(GCNV_VARIANT3, [-1]), _sorted(GCNV_VARIANT4, [-1]), _sorted(GCNV_VARIANT1, [0]), _sorted(GCNV_VARIANT2, [0])],
omit_sample_type='SNV_INDEL', sort=sort, sort_metadata=omim_sort_metadata,
omit_sample_type='SNV_INDEL', sort=sort, sort_metadata=OMIM_SORT_METADATA,
)

await self._assert_expected_search(
[_sorted(MULTI_FAMILY_VARIANT, [0, -2]), _sorted(VARIANT2, [0, -1]), _sorted(VARIANT4, [0, -1]),
_sorted(GCNV_VARIANT3, [0, -1]), _sorted(GCNV_VARIANT4, [0, -1]), _sorted(GCNV_VARIANT1, [0, 0]),
_sorted(GCNV_VARIANT2, [0, 0]), _sorted(VARIANT1, [1, 0])], sort=sort, sort_metadata=omim_sort_metadata,
_sorted(GCNV_VARIANT2, [0, 0]), _sorted(VARIANT1, [1, 0])], sort=sort, sort_metadata=OMIM_SORT_METADATA,
)

await self._assert_expected_search(
Expand Down Expand Up @@ -1161,14 +1162,7 @@ async def test_sort(self):

# sort applies to compound hets
await self._assert_expected_search(
[[_sorted(GCNV_VARIANT4, [0]), _sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [11, 11])],
_sorted(GCNV_VARIANT3, [4.5, 0]), [_sorted(GCNV_VARIANT3, [0]), _sorted(GCNV_VARIANT4, [0])],
_sorted(VARIANT2, [11, 11]), [_sorted(VARIANT4, [11, 11]), _sorted(VARIANT3, [22, 24])]],
sort='protein_consequence', inheritance_mode='recessive', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT4, [-0.5260000228881836]), _sorted(VARIANT3, [None])],
[[_sorted(VARIANT4, [-0.5260000228881836]), _sorted(VARIANT3, [0])],
_sorted(VARIANT2, [-0.19699999690055847])],
sort='revel', inheritance_mode='recessive', omit_sample_type='SV_WES', **COMP_HET_ALL_PASS_FILTERS,
)
Expand All @@ -1177,3 +1171,67 @@ async def test_sort(self):
[[_sorted(VARIANT3, [-0.009999999776482582]), _sorted(VARIANT4, [0])], _sorted(VARIANT2, [0])],
sort='splice_ai', inheritance_mode='recessive', omit_sample_type='SV_WES', **COMP_HET_ALL_PASS_FILTERS,
)

async def test_multi_data_type_comp_het_sort(self):
await self._assert_expected_search(
[_sorted(GCNV_VARIANT3, [4.5, 0]), [_sorted(GCNV_VARIANT3, [0]), _sorted(GCNV_VARIANT4, [0])],
[_sorted(GCNV_VARIANT4, [4.5, 0]), _sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [11, 11])],
_sorted(VARIANT2, [11, 11]), [_sorted(VARIANT4, [11, 11]), _sorted(VARIANT3, [22, 24])]],
sort='protein_consequence', inheritance_mode='recessive', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(GCNV_VARIANT4, [-14487]), _sorted(GCNV_VARIANT3, [-2666])],
[_sorted(GCNV_VARIANT4, [-14487]), MULTI_DATA_TYPE_COMP_HET_VARIANT2],
[VARIANT3, VARIANT4]],
sort='size', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [8]), GCNV_VARIANT4],
[_sorted(VARIANT3, [12.5]), _sorted(VARIANT4, [12.5])],
[GCNV_VARIANT3, GCNV_VARIANT4]],
sort='pathogenicity', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT4, [-0.6869999766349792]), _sorted(VARIANT3, [0])], _sorted(VARIANT2, [0]),
[_sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [0]), GCNV_VARIANT4],
GCNV_VARIANT3, [GCNV_VARIANT3, GCNV_VARIANT4]],
sort='mut_pred', inheritance_mode='recessive', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT3, [-0.009999999776482582]), _sorted(VARIANT4, [0])],
[_sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [0]), GCNV_VARIANT4],
[GCNV_VARIANT3, GCNV_VARIANT4]],
sort='splice_ai', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(GCNV_VARIANT3, [-0.7860000133514404]), _sorted(GCNV_VARIANT4, [-0.7099999785423279])],
[_sorted(GCNV_VARIANT4, [-0.7099999785423279]), MULTI_DATA_TYPE_COMP_HET_VARIANT2],
[VARIANT3, VARIANT4]],
sort='strvctvre', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(GCNV_VARIANT3, [0.0015185698866844177]), _sorted(GCNV_VARIANT4, [0.004989586770534515])],
[_sorted(GCNV_VARIANT4, [0.004989586770534515]), _sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [0.31111112236976624])],
[_sorted(VARIANT4, [0.02222222276031971]), _sorted(VARIANT3, [0.6666666865348816])]],
sort='callset_af', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT3, [0]), _sorted(VARIANT4, [0])],
[_sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [0.28899794816970825]), GCNV_VARIANT4],
[GCNV_VARIANT3, GCNV_VARIANT4]],
sort='gnomad_exomes', inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

await self._assert_expected_search(
[[_sorted(VARIANT3, [0, -2]), _sorted(VARIANT4, [0, -1])],
[_sorted(GCNV_VARIANT3, [-1]), _sorted(GCNV_VARIANT4, [-1])],
[_sorted(GCNV_VARIANT4, [0, -1]), _sorted(MULTI_DATA_TYPE_COMP_HET_VARIANT2, [1, -1])]],
sort='in_omim', sort_metadata=OMIM_SORT_METADATA, inheritance_mode='compound_het', **COMP_HET_ALL_PASS_FILTERS,
)

0 comments on commit 760e8c8

Please sign in to comment.