Skip to content

Commit

Permalink
Merge pull request #3536 from broadinstitute/hail-backend-enum-setting
Browse files: browse the repository at this point in the history
save table globals to class instance
  • Loading branch information
hanars authored Aug 3, 2023
2 parents 6483883 + f8a3707 commit b933c47
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions hail_search/hail_search_query.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -64,27 +64,24 @@ def _format_population_config(cls, pop_config):
return base_pop_config

def annotation_fields(self):
ht_globals = {k: hl.eval(self._ht[k]) for k in self.GLOBALS}
enums = ht_globals.pop('enums')

annotation_fields = {
'populations': lambda r: hl.struct(**{
population: self.population_expression(r, population) for population in self.POPULATIONS.keys()
}),
'predictions': lambda r: hl.struct(**{
prediction: hl.array(enums[path.source][path.field])[r[path.source][f'{path.field}_id']]
if enums.get(path.source, {}).get(path.field) else r[path.source][path.field]
prediction: hl.array(self._enums[path.source][path.field])[r[path.source][f'{path.field}_id']]
if self._enums.get(path.source, {}).get(path.field) else r[path.source][path.field]
for prediction, path in self.PREDICTION_FIELDS_CONFIG.items()
}),
'transcripts': lambda r: hl.or_else(
r.sorted_transcript_consequences, hl.empty_array(r.sorted_transcript_consequences.dtype.element_type)
).map(
lambda t: self._enum_field(t, enums['sorted_transcript_consequences'], **self._format_transcript_args())
lambda t: self._enum_field(t, self._enums['sorted_transcript_consequences'], **self._format_transcript_args())
).group_by(lambda t: t.geneId),
}
annotation_fields.update(self.BASE_ANNOTATION_FIELDS)

format_enum = lambda k, enum_config: lambda r: self._enum_field(r[k], enums[k], ht_globals=ht_globals, **enum_config)
format_enum = lambda k, enum_config: lambda r: self._enum_field(r[k], self._enums[k], ht_globals=self._globals, **enum_config)
annotation_fields.update({
enum_config.get('response_key', k): format_enum(k, enum_config)
for k, enum_config in self.ENUM_ANNOTATION_FIELDS.items()
Expand Down Expand Up @@ -139,6 +136,8 @@ def __init__(self, data_type, sample_data, genome_version, sort=XPOS, num_result
self._sort = sort
self._num_results = num_results
self._ht = None
self._enums = None
self._globals = None

self._load_filtered_table(data_type, sample_data, **kwargs)

Expand Down Expand Up @@ -197,8 +196,11 @@ def import_filtered_table(self, data_type, sample_data, intervals=None, **kwargs
annotation_ht_query_result = hl.query_table(
annotations_ht_path, families_ht.key).first().drop(*families_ht.key)
ht = families_ht.annotate(**annotation_ht_query_result)
# Add globals
ht = ht.join(hl.read_table(annotations_ht_path).head(0).select().select_globals(*self.GLOBALS), how='left')

# Get globals
annotation_globals_ht = hl.read_table(annotations_ht_path).head(0).select()
self._globals = {k: hl.eval(annotation_globals_ht[k]) for k in self.GLOBALS}
self._enums = self._globals.pop('enums')

self._ht = ht.transmute(
genotypes=ht.family_entries.flatmap(lambda x: x).filter(
Expand Down

0 comments on commit b933c47

Please sign in to comment.