From a16973f4644a962fd56c60cb942e4e093a1d4484 Mon Sep 17 00:00:00 2001
From: Huilin Qu
Date: Fri, 15 Dec 2023 17:22:03 +0800
Subject: [PATCH] Support jagged arrays as observers.

---
 setup.py                    |  2 +-
 weaver/train.py             | 14 ++++++++------
 weaver/utils/data/config.py |  2 ++
 weaver/utils/dataset.py     | 34 +++++++++++++++++++---------------
 weaver/utils/nn/tools.py    | 25 +++++++++++++++----------
 5 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/setup.py b/setup.py
index ba7ba245..45c157da 100644
--- a/setup.py
+++ b/setup.py
@@ -16,7 +16,7 @@
             install_requires.append(line)
 
 setup(name="weaver-core",
-      version='0.4.8',
+      version='0.4.9',
       description="A streamlined deep-learning framework for high energy physics",
       long_description_content_type="text/markdown",
       author="H. Qu, C. Li",
diff --git a/weaver/train.py b/weaver/train.py
index 1f4c39d6..51f5c57f 100644
--- a/weaver/train.py
+++ b/weaver/train.py
@@ -607,14 +607,16 @@ def iotest(args, data_loader):
 
     for X, y, Z in tqdm(data_loader):
         for k, v in Z.items():
-            monitor_info[k].append(v.cpu().numpy())
+            monitor_info[k].append(v)
     monitor_info = {k: _concat(v) for k, v in monitor_info.items()}
     if monitor_info:
-        monitor_output_path = 'weaver_monitor_info.pkl'
-        import pickle
-        with open(monitor_output_path, 'wb') as f:
-            pickle.dump(monitor_info, f)
-        _logger.info('Monitor info written to %s' % monitor_output_path)
+        monitor_output_path = 'weaver_monitor_info.parquet'
+        try:
+            import awkward as ak
+            ak.to_parquet(ak.Array(monitor_info), monitor_output_path, compression='LZ4', compression_level=4)
+            _logger.info('Monitor info written to %s' % monitor_output_path, color='bold')
+        except Exception as e:
+            _logger.error('Error when writing output parquet file: \n' + str(e))
 
 
 def save_root(args, output_path, data_config, scores, labels, observers):
diff --git a/weaver/utils/data/config.py b/weaver/utils/data/config.py
index b32bb038..bfb6a24a 100644
--- a/weaver/utils/data/config.py
+++ b/weaver/utils/data/config.py
@@ -134,6 +134,8 @@ def _get(idx, default):
         self.observer_names = tuple(opts['observers'])
         # monitor variables
         self.monitor_variables = tuple(opts['monitor_variables'])
+        if self.observer_names and self.monitor_variables:
+            raise RuntimeError('Cannot set `observers` and `monitor_variables` at the same time.')
         # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
         self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables
 
diff --git a/weaver/utils/dataset.py b/weaver/utils/dataset.py
index 45cca82c..5d48fb46 100644
--- a/weaver/utils/dataset.py
+++ b/weaver/utils/dataset.py
@@ -8,22 +8,22 @@
 from functools import partial
 from concurrent.futures.thread import ThreadPoolExecutor
 from .logger import _logger
-from .data.tools import _pad, _repeat_pad, _clip
+from .data.tools import _pad, _repeat_pad, _clip, _stack
 from .data.fileio import _read_files
 from .data.config import DataConfig, _md5
 from .data.preprocess import _apply_selection, _build_new_variables, _build_weights, AutoStandardizer, WeightMaker
 
 
+def _collate_awkward_array_fn(batch, *, collate_fn_map=None):
+    return _stack(batch, axis=0)
+
+
 def _finalize_inputs(table, data_config):
     output = {}
     # copy observer variables before transformation
     for k in data_config.z_variables:
         if k in data_config.observer_names:
-            a = ak.to_numpy(table[k])
-            if a.dtype in (np.uint16, np.uint32, np.uint64):
-                # FIXME: hack as torch only supports float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool
-                a = a.astype('int64')
-            output[k] = a
+            output[k] = table[k]  # ak.Array
     # copy labels
     for k in data_config.label_names:
         output[k] = ak.to_numpy(table[k])
@@ -47,10 +47,10 @@ def _finalize_inputs(table, data_config):
             output['_' + k] = ak.to_numpy(ak.values_astype(table[names[0]], 'float32'))
         else:
             output['_' + k] = ak.to_numpy(np.stack([ak.to_numpy(table[n]).astype('float32') for n in names], axis=1))
-    # copy monitor variables
+    # copy monitor variables (after transformation)
     for k in data_config.z_variables:
-        if k not in output:
-            output[k] = ak.to_numpy(table[k])
+        if k in data_config.monitor_variables:
+            output[k] = table[k]  # ak.Array
     return output
 
 
@@ -67,7 +67,7 @@ def _get_reweight_indices(weights, up_sample=True, max_resample=10, weight_scale
     all_indices = np.repeat(np.arange(len(weights)), n_repeats)
     randwgt = np.random.uniform(low=0, high=weight_scale, size=len(weights) * n_repeats)
     keep_indices = all_indices[randwgt < np.repeat(weights, n_repeats)]
-    return keep_indices.copy()
+    return copy.deepcopy(keep_indices)
 
 
 def _check_labels(table):
@@ -135,7 +135,7 @@
     def __init__(self, **kwargs):
         self._seed = None
         worker_info = torch.utils.data.get_worker_info()
-        file_dict = self._init_file_dict.copy()
+        file_dict = copy.deepcopy(self._init_file_dict)
         if worker_info is not None:
             # in a worker process
             self._name += '_worker%d' % worker_info.id
@@ -156,7 +156,7 @@
     def restart(self):
         print('=== Restarting DataIter %s, seed=%s ===' % (self._name, self._seed))
         # re-shuffle filelist and load range if for training
-        filelist = self.worker_filelist.copy()
+        filelist = copy.deepcopy(self.worker_filelist)
         if self._sampler_options['shuffle']:
             np.random.shuffle(filelist)
         if self._file_fraction < 1:
@@ -260,11 +260,11 @@ def _try_get_next(self, init=False):
 
     def get_data(self, i):
         # inputs
-        X = {k: self.table['_' + k][i].copy() for k in self._data_config.input_names}
+        X = {k: copy.deepcopy(self.table['_' + k][i]) for k in self._data_config.input_names}
         # labels
-        y = {k: self.table[k][i].copy() for k in self._data_config.label_names}
+        y = {k: copy.deepcopy(self.table[k][i]) for k in self._data_config.label_names}
         # observers / monitor variables
-        Z = {k: self.table[k][i].copy() for k in self._data_config.z_variables}
+        Z = {k: copy.deepcopy(self.table[k][i]) for k in self._data_config.z_variables}
         return X, y, Z
 
 
@@ -315,6 +315,10 @@ def __init__(self, file_dict, data_config_file,
             'max_resample': max_resample,
         }
 
+        # ==== torch collate_fn map ====
+        from torch.utils.data._utils.collate import default_collate_fn_map
+        default_collate_fn_map.update({ak.Array: _collate_awkward_array_fn})
+
         if for_training:
             self._sampler_options.update(training=True, shuffle=True, reweight=True)
         else:
diff --git a/weaver/utils/nn/tools.py b/weaver/utils/nn/tools.py
index cf3f1316..6727378f 100644
--- a/weaver/utils/nn/tools.py
+++ b/weaver/utils/nn/tools.py
@@ -59,12 +59,14 @@ def train_classification(
     total_loss = 0
     num_batches = 0
     total_correct = 0
+    entry_count = 0
     count = 0
     start_time = time.time()
     with tqdm.tqdm(train_loader) as tq:
         for X, y, _ in tq:
             inputs = [X[k].to(dev) for k in data_config.input_names]
             label = y[data_config.label_names[0]].long().to(dev)
+            entry_count += label.shape[0]
             try:
                 mask = y[data_config.label_names[0] + '_mask'].bool().to(dev)
             except KeyError:
@@ -117,7 +119,7 @@
                 break
 
     time_diff = time.time() - start_time
-    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (count, count / time_diff))
+    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (entry_count, entry_count / time_diff))
     _logger.info('Train AvgLoss: %.5f, AvgAcc: %.5f' % (total_loss / num_batches, total_correct / count))
     _logger.info('Train class distribution: \n %s', str(sorted(label_counter.items())))
 
@@ -157,6 +159,7 @@ def evaluate_classification(model, test_loader, dev, epoch, for_training=True, l
     with torch.no_grad():
         with tqdm.tqdm(test_loader) as tq:
             for X, y, Z in tq:
+                # X, y: torch.Tensor; Z: ak.Array
                 inputs = [X[k].to(dev) for k in data_config.input_names]
                 label = y[data_config.label_names[0]].long().to(dev)
                 entry_count += label.shape[0]
@@ -174,7 +177,7 @@
                     labels[k].append(_flatten_label(v, mask).numpy(force=True))
                 if not for_training:
                     for k, v in Z.items():
-                        observers[k].append(v.numpy(force=True))
+                        observers[k].append(v)
 
                 num_examples = label.shape[0]
                 label_counter.update(label.numpy(force=True))
@@ -206,7 +209,7 @@
                     break
 
     time_diff = time.time() - start_time
-    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (count, count / time_diff))
+    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (entry_count, entry_count / time_diff))
     _logger.info('Evaluation class distribution: \n %s', str(sorted(label_counter.items())))
 
     if tb_helper:
@@ -259,8 +262,9 @@ def evaluate_onnx(model_path, test_loader, eval_metrics=['roc_auc_score', 'roc_a
     start_time = time.time()
     with tqdm.tqdm(test_loader) as tq:
         for X, y, Z in tq:
-            inputs = {k: v.cpu().numpy() for k, v in X.items()}
-            label = y[data_config.label_names[0]].cpu().numpy()
+            # X, y: torch.Tensor; Z: ak.Array
+            inputs = {k: v.numpy(force=True) for k, v in X.items()}
+            label = y[data_config.label_names[0]].numpy(force=True)
             num_examples = label.shape[0]
             label_counter.update(label)
             score = sess.run([], inputs)[0]
@@ -268,9 +272,9 @@
 
             scores.append(score)
             for k, v in y.items():
-                labels[k].append(v.cpu().numpy())
+                labels[k].append(v.numpy(force=True))
             for k, v in Z.items():
-                observers[k].append(v.cpu().numpy())
+                observers[k].append(v)
 
             correct = (preds == label).sum()
             total_correct += correct
@@ -404,6 +408,7 @@ def evaluate_regression(model, test_loader, dev, epoch, for_training=True, loss_
     with torch.no_grad():
         with tqdm.tqdm(test_loader) as tq:
             for X, y, Z in tq:
+                # X, y: torch.Tensor; Z: ak.Array
                 inputs = [X[k].to(dev) for k in data_config.input_names]
                 label = y[data_config.label_names[0]].float()
                 num_examples = label.shape[0]
@@ -411,12 +416,12 @@
                 model_output = model(*inputs)
                 preds = model_output.squeeze().float()
 
-                scores.append(preds.detach().cpu().numpy())
+                scores.append(preds.numpy(force=True))
                 for k, v in y.items():
-                    labels[k].append(v.cpu().numpy())
+                    labels[k].append(v.numpy(force=True))
                 if not for_training:
                     for k, v in Z.items():
-                        observers[k].append(v.cpu().numpy())
+                        observers[k].append(v)
 
                 loss = 0 if loss_func is None else loss_func(preds, label).item()
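-- 

Notes on the change (sketches, not part of the applied diff):

torch's `default_collate` dispatches on element type via
`default_collate_fn_map`, so the registration added in dataset.py makes
batches of jagged observers come out of the dataloader as awkward arrays
instead of torch tensors. Below is a minimal, self-contained sketch of that
mechanism; it substitutes a plain `ak.concatenate` for weaver's `_stack`
helper (whose implementation is not shown in this patch), and the sample
field names (`x`, `jet_pt`) are made up for illustration:

    import awkward as ak
    import numpy as np
    from torch.utils.data import default_collate
    from torch.utils.data._utils.collate import default_collate_fn_map

    def _collate_awkward_array_fn(batch, *, collate_fn_map=None):
        # stack the per-sample jagged arrays along a new leading (batch) axis
        return ak.concatenate([a[np.newaxis] for a in batch], axis=0)

    default_collate_fn_map.update({ak.Array: _collate_awkward_array_fn})

    # one sample per event: a fixed-size input plus a variable-length observer
    samples = [
        {'x': np.float32(1.0), 'jet_pt': ak.Array([10.0, 20.0])},
        {'x': np.float32(2.0), 'jet_pt': ak.Array([5.0])},
    ]
    batch = default_collate(samples)
    print(batch['x'])       # tensor([1., 2.])
    print(batch['jet_pt'])  # [[10, 20], [5]], still jagged, no padding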
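The iotest hunk replaces the pickle dump with `ak.to_parquet`, so jagged
monitor variables survive the round trip. A small sketch of reading the file
back (the path comes from the patch; which fields exist depends on the
`monitor_variables` set in the data card):

    import awkward as ak

    monitor_info = ak.from_parquet('weaver_monitor_info.parquet')
    print(monitor_info.fields)  # one field per monitor variable
    # each field may be jagged, e.g. a per-event list of constituent values
    print(monitor_info[monitor_info.fields[0]])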
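One plausible reading of the `.copy()` to `copy.deepcopy(...)` substitutions
in dataset.py: the sliced `Z` entries can now be `ak.Array` values, and
attribute access on awkward arrays is reserved for record fields rather than
a `.copy()` method, while `copy.deepcopy` handles numpy and awkward values
alike. A quick illustration with made-up values:

    import copy

    import awkward as ak
    import numpy as np

    np_row = np.array([1.0, 2.0])
    ak_row = ak.Array([[1.0], [2.0, 3.0]])[0]  # one jagged "event"

    np_row.copy()          # fine: numpy arrays implement .copy()
    # ak_row.copy()        # would raise AttributeError on an ak.Array
    copy.deepcopy(ak_row)  # works uniformly for numpy and awkward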