diff --git a/setup.py b/setup.py index 51994808..52c93a8b 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ install_requires.append(line) setup(name="weaver-core", - version='0.4.2', + version='0.4.3', description="A streamlined deep-learning framework for high energy physics", long_description_content_type="text/markdown", author="H. Qu, C. Li", diff --git a/weaver/utils/data/config.py b/weaver/utils/data/config.py index 1b3075c7..abd0b7c9 100644 --- a/weaver/utils/data/config.py +++ b/weaver/utils/data/config.py @@ -33,6 +33,7 @@ def __init__(self, print_info=True, **kwargs): opts = { 'treename': None, + 'branch_magic': None, 'selection': None, 'test_time_selection': None, 'preprocess': {'method': 'manual', 'data_fraction': 0.1, 'params': None}, diff --git a/weaver/utils/data/fileio.py b/weaver/utils/data/fileio.py index 2d01fec5..0f9c9da1 100644 --- a/weaver/utils/data/fileio.py +++ b/weaver/utils/data/fileio.py @@ -20,7 +20,7 @@ def _read_hdf5(filepath, branches, load_range=None): return ak.Array(outputs) -def _read_root(filepath, branches, load_range=None, treename=None): +def _read_root(filepath, branches, load_range=None, treename=None, branch_magic=None): import uproot with uproot.open(filepath) as f: if treename is None: @@ -30,14 +30,27 @@ def _read_root(filepath, branches, load_range=None, treename=None): else: raise RuntimeError( 'Need to specify `treename` as more than one trees are found in file %s: %s' % - (filepath, str(branches))) + (filepath, str(treenames))) tree = f[treename] if load_range is not None: start = math.trunc(load_range[0] * tree.num_entries) stop = max(start + 1, math.trunc(load_range[1] * tree.num_entries)) else: start, stop = None, None - outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop) + if branch_magic is not None: + branch_dict = {} + for name in branches: + decoded_name = name + for src, tgt in branch_magic.items(): + if src in decoded_name: + decoded_name = decoded_name.replace(src, tgt) + branch_dict[name] = decoded_name + outputs = tree.arrays(filter_name=list(branch_dict.values()), entry_start=start, entry_stop=stop) + for name, decoded_name in branch_dict.items(): + if name != decoded_name: + outputs[name] = outputs[decoded_name] + else: + outputs = tree.arrays(filter_name=branches, entry_start=start, entry_stop=stop) return outputs @@ -77,7 +90,9 @@ def _read_files(filelist, branches, load_range=None, show_progressbar=False, **k if ext == '.h5': a = _read_hdf5(filepath, branches, load_range=load_range) elif ext == '.root': - a = _read_root(filepath, branches, load_range=load_range, treename=kwargs.get('treename', None)) + a = _read_root(filepath, branches, load_range=load_range, + treename=kwargs.get('treename', None), + branch_magic=kwargs.get('branch_magic', None)) elif ext == '.awkd': a = _read_awkd(filepath, branches, load_range=load_range) elif ext == '.parquet': diff --git a/weaver/utils/data/preprocess.py b/weaver/utils/data/preprocess.py index 297eb899..3909e511 100644 --- a/weaver/utils/data/preprocess.py +++ b/weaver/utils/data/preprocess.py @@ -104,8 +104,8 @@ def read_file(self, filelist): self.load_branches.update(_get_variable_names(self._data_config.selection)) _logger.debug('[AutoStandardizer] keep_branches:\n %s', ','.join(self.keep_branches)) _logger.debug('[AutoStandardizer] load_branches:\n %s', ','.join(self.load_branches)) - table = _read_files(filelist, self.load_branches, self.load_range, - show_progressbar=True, treename=self._data_config.treename) + table = _read_files(filelist, self.load_branches, self.load_range, show_progressbar=True, + treename=self._data_config.treename, branch_magic=self._data_config.branch_magic) table = _apply_selection(table, self._data_config.selection) table = _build_new_variables( table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches}) @@ -178,7 +178,8 @@ def read_file(self, filelist): self.load_branches.update(_get_variable_names(self._data_config.selection)) _logger.debug('[WeightMaker] keep_branches:\n %s', ','.join(self.keep_branches)) _logger.debug('[WeightMaker] load_branches:\n %s', ','.join(self.load_branches)) - table = _read_files(filelist, self.load_branches, show_progressbar=True, treename=self._data_config.treename) + table = _read_files(filelist, self.load_branches, show_progressbar=True, + treename=self._data_config.treename, branch_magic=self._data_config.branch_magic) table = _apply_selection(table, self._data_config.selection) table = _build_new_variables( table, {k: v for k, v in self._data_config.var_funcs.items() if k in self.keep_branches}) diff --git a/weaver/utils/dataset.py b/weaver/utils/dataset.py index 1ed04637..7a268c64 100644 --- a/weaver/utils/dataset.py +++ b/weaver/utils/dataset.py @@ -106,7 +106,8 @@ def _preprocess(table, data_config, options): def _load_next(data_config, filelist, load_range, options): - table = _read_files(filelist, data_config.load_branches, load_range, treename=data_config.treename) + table = _read_files(filelist, data_config.load_branches, load_range, + treename=data_config.treename, branch_magic=data_config.branch_magic) table, indices = _preprocess(table, data_config, options) return table, indices @@ -142,7 +143,7 @@ def __init__(self, **kwargs): new_file_dict = {} for name, files in file_dict.items(): new_files = files[worker_info.id::worker_info.num_workers] - assert(len(new_files) > 0) + assert (len(new_files) > 0) new_file_dict[name] = new_files file_dict = new_file_dict self.worker_file_dict = file_dict