Skip to content

Commit

Permalink
Add support for reading object branches in ROOT.
Browse files Browse the repository at this point in the history
To use this, add:
```
branch_magic:
   __DOT__: .
```
to the YAML config, and rewrite each object branch such as `obj.x` as `obj__DOT__x` in
the YAML file. The `branch_magic` mapping is used to translate the branch
names when reading the ROOT files, but any intermediate computation still
uses the escaped names, e.g. `obj__DOT__x`.
  • Loading branch information
hqucms committed Jun 14, 2023
1 parent af16ad8 commit 9c4b72f
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
install_requires.append(line)

setup(name="weaver-core",
version='0.4.2',
version='0.4.3',
description="A streamlined deep-learning framework for high energy physics",
long_description_content_type="text/markdown",
author="H. Qu, C. Li",
Expand Down
1 change: 1 addition & 0 deletions weaver/utils/data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(self, print_info=True, **kwargs):

opts = {
'treename': None,
'branch_magic': None,
'selection': None,
'test_time_selection': None,
'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None},
Expand Down
23 changes: 19 additions & 4 deletions weaver/utils/data/fileio.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _read_hdf5(filepath, branches, load_range=None):
return ak.Array(outputs)


def _read_root(filepath, branches, load_range=None, treename=None):
def _read_root(filepath, branches, load_range=None, treename=None, branch_magic=None):
import uproot
with uproot.open(filepath) as f:
if treename is None:
Expand All @@ -30,14 +30,27 @@ def _read_root(filepath, branches, load_range=None, treename=None):
else:
raise RuntimeError(
'Need to specify `treename` as more than one trees are found in file %s: %s' %
(filepath, str(branches)))
(filepath, str(treenames)))
tree = f[treename]
if load_range is not None:
start = math.trunc(load_range[0] * tree.num_entries)
stop = max(start + 1, math.trunc(load_range[1] * tree.num_entries))
else:
start, stop = None, None
outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop)
if branch_magic is not None:
branch_dict = {}
for name in branches:
decoded_name = name
for src, tgt in branch_magic.items():
if src in decoded_name:
decoded_name = decoded_name.replace(src, tgt)
branch_dict[name] = decoded_name
outputs = tree.arrays(filter_name=list(branch_dict.values()), entry_start=start, entry_stop=stop)
for name, decoded_name in branch_dict.items():
if name != decoded_name:
outputs[name] = outputs[decoded_name]
else:
outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop)
return outputs


Expand Down Expand Up @@ -77,7 +90,9 @@ def _read_files(filelist, branches, load_range=None, show_progressbar=False, **k
if ext == '.h5':
a = _read_hdf5(filepath, branches, load_range=load_range)
elif ext == '.root':
a = _read_root(filepath, branches, load_range=load_range, treename=kwargs.get('treename', None))
a = _read_root(filepath, branches, load_range=load_range,
treename=kwargs.get('treename', None),
branch_magic=kwargs.get('branch_magic', None))
elif ext == '.awkd':
a = _read_awkd(filepath, branches, load_range=load_range)
elif ext == '.parquet':
Expand Down
7 changes: 4 additions & 3 deletions weaver/utils/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ def read_file(self, filelist):
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, self.load_range,
show_progressbar=True, treename=self._data_config.treename)
table = _read_files(filelist, self.load_branches, self.load_range, show_progressbar=True,
treename=self._data_config.treename, branch_magic=self._data_config.branch_magic)
table = _apply_selection(table, self._data_config.selection)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
Expand Down Expand Up @@ -178,7 +178,8 @@ def read_file(self, filelist):
self.load_branches.update(_get_variable_names(self._data_config.selection))
_logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches))
_logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches))
table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename)
table = _read_files(filelist, self.load_branches, show_progressbar=True,
treename=self._data_config.treename, branch_magic=self._data_config.branch_magic)
table = _apply_selection(table, self._data_config.selection)
table = _build_new_variables(
table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches})
Expand Down
5 changes: 3 additions & 2 deletions weaver/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ def _preprocess(table, data_config, options):


def _load_next(data_config, filelist, load_range, options):
table = _read_files(filelist, data_config.load_branches, load_range, treename=data_config.treename)
table = _read_files(filelist, data_config.load_branches, load_range,
treename=data_config.treename, branch_magic=data_config.branch_magic)
table, indices = _preprocess(table, data_config, options)
return table, indices

Expand Down Expand Up @@ -142,7 +143,7 @@ def __init__(self, **kwargs):
new_file_dict = {}
for name, files in file_dict.items():
new_files = files[worker_info.id::worker_info.num_workers]
assert(len(new_files) > 0)
assert (len(new_files) > 0)
new_file_dict[name] = new_files
file_dict = new_file_dict
self.worker_file_dict = file_dict
Expand Down

0 comments on commit 9c4b72f

Please sign in to comment.