Support jagged arrays as observers.
hqucms committed Dec 15, 2023
1 parent 59f0c6b commit a16973f
Showing 5 changed files with 45 additions and 32 deletions.

setup.py (1 addition, 1 deletion)

@@ -16,7 +16,7 @@
         install_requires.append(line)

 setup(name="weaver-core",
-      version='0.4.8',
+      version='0.4.9',
       description="A streamlined deep-learning framework for high energy physics",
       long_description_content_type="text/markdown",
       author="H. Qu, C. Li",

weaver/train.py (8 additions, 6 deletions)

@@ -607,14 +607,16 @@ def iotest(args, data_loader):

     for X, y, Z in tqdm(data_loader):
         for k, v in Z.items():
-            monitor_info[k].append(v.cpu().numpy())
+            monitor_info[k].append(v)
     monitor_info = {k: _concat(v) for k, v in monitor_info.items()}
     if monitor_info:
-        monitor_output_path = 'weaver_monitor_info.pkl'
-        import pickle
-        with open(monitor_output_path, 'wb') as f:
-            pickle.dump(monitor_info, f)
-        _logger.info('Monitor info written to %s' % monitor_output_path)
+        monitor_output_path = 'weaver_monitor_info.parquet'
+        try:
+            import awkward as ak
+            ak.to_parquet(ak.Array(monitor_info), monitor_output_path, compression='LZ4', compression_level=4)
+            _logger.info('Monitor info written to %s' % monitor_output_path, color='bold')
+        except Exception as e:
+            _logger.error('Error when writing output parquet file: \n' + str(e))


 def save_root(args, output_path, data_config, scores, labels, observers):
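
Why Parquet: pickle round-trips plain NumPy arrays fine, but observers can now be jagged `ak.Array` objects, which awkward writes natively to Parquet. A minimal standalone sketch of the new code path (array and field names invented for illustration; requires pyarrow):

```python
import awkward as ak

# per-batch jagged arrays, e.g. a variable-length list of PF-candidate pT per jet
batches = [ak.Array([[1.0, 2.5], [0.7]]), ak.Array([[3.1], [0.2, 4.0, 1.1]])]

# mimic weaver's _concat: join the batches along the event axis
monitor_info = {'pfcand_pt': ak.concatenate(batches, axis=0)}

# ak.Array(dict) zips the fields into one record array
ak.to_parquet(ak.Array(monitor_info), 'weaver_monitor_info.parquet',
              compression='LZ4', compression_level=4)

# reading it back preserves the jagged structure
monitor_info = ak.from_parquet('weaver_monitor_info.parquet')
```
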

weaver/utils/data/config.py (2 additions, 0 deletions)

@@ -134,6 +134,8 @@ def _get(idx, default):
         self.observer_names = tuple(opts['observers'])
         # monitor variables
         self.monitor_variables = tuple(opts['monitor_variables'])
+        if self.observer_names and self.monitor_variables:
+            raise RuntimeError('Cannot set `observers` and `monitor_variables` at the same time.')
         # Z variables: returned as `Z` in the dataloader (use monitor_variables for training, observers for eval)
         self.z_variables = self.observer_names if len(self.observer_names) > 0 else self.monitor_variables
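
Both option blocks feed the same `Z` output of the dataloader, so they now exclude each other and the config fails fast instead of silently ignoring one. A toy illustration of the new resolution logic (the `opts` contents are invented):

```python
# hypothetical config contents
opts = {'observers': ['jet_pt', 'jet_eta'], 'monitor_variables': []}

observer_names = tuple(opts['observers'])
monitor_variables = tuple(opts['monitor_variables'])
if observer_names and monitor_variables:
    raise RuntimeError('Cannot set `observers` and `monitor_variables` at the same time.')
# `Z` is filled from observers when set, from monitor_variables otherwise
z_variables = observer_names if len(observer_names) > 0 else monitor_variables
```
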

weaver/utils/dataset.py (19 additions, 15 deletions)

@@ -8,22 +8,22 @@
 from functools import partial
 from concurrent.futures.thread import ThreadPoolExecutor
 from .logger import _logger
-from .data.tools import _pad, _repeat_pad, _clip
+from .data.tools import _pad, _repeat_pad, _clip, _stack
 from .data.fileio import _read_files
 from .data.config import DataConfig, _md5
 from .data.preprocess import _apply_selection, _build_new_variables, _build_weights, AutoStandardizer, WeightMaker


+def _collate_awkward_array_fn(batch, *, collate_fn_map=None):
+    return _stack(batch, axis=0)
+
+
 def _finalize_inputs(table, data_config):
     output = {}
     # copy observer variables before transformation
     for k in data_config.z_variables:
         if k in data_config.observer_names:
-            a = ak.to_numpy(table[k])
-            if a.dtype in (np.uint16, np.uint32, np.uint64):
-                # FIXME: hack as torch only supports float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool
-                a = a.astype('int64')
-            output[k] = a
+            output[k] = table[k]  # ak.Array
     # copy labels
     for k in data_config.label_names:
         output[k] = ak.to_numpy(table[k])
@@ -47,10 +47,10 @@ def _finalize_inputs(table, data_config):
             output['_' + k] = ak.to_numpy(ak.values_astype(table[names[0]], 'float32'))
         else:
             output['_' + k] = ak.to_numpy(np.stack([ak.to_numpy(table[n]).astype('float32') for n in names], axis=1))
-    # copy monitor variables
+    # copy monitor variables (after transformation)
     for k in data_config.z_variables:
-        if k not in output:
-            output[k] = ak.to_numpy(table[k])
+        if k in data_config.monitor_variables:
+            output[k] = table[k]  # ak.Array
     return output


@@ -67,7 +67,7 @@ def _get_reweight_indices(weights, up_sample=True, max_resample=10, weight_scale
     all_indices = np.repeat(np.arange(len(weights)), n_repeats)
     randwgt = np.random.uniform(low=0, high=weight_scale, size=len(weights) * n_repeats)
     keep_indices = all_indices[randwgt < np.repeat(weights, n_repeats)]
-    return keep_indices.copy()
+    return copy.deepcopy(keep_indices)


def _check_labels(table):
@@ -135,7 +135,7 @@ def __init__(self, **kwargs):

         self._seed = None
         worker_info = torch.utils.data.get_worker_info()
-        file_dict = self._init_file_dict.copy()
+        file_dict = copy.deepcopy(self._init_file_dict)
         if worker_info is not None:
             # in a worker process
             self._name += '_worker%d' % worker_info.id
@@ -156,7 +156,7 @@ def restart(self):
     def restart(self):
         print('=== Restarting DataIter %s, seed=%s ===' % (self._name, self._seed))
         # re-shuffle filelist and load range if for training
-        filelist = self.worker_filelist.copy()
+        filelist = copy.deepcopy(self.worker_filelist)
         if self._sampler_options['shuffle']:
             np.random.shuffle(filelist)
         if self._file_fraction < 1:
@@ -260,11 +260,11 @@ def _try_get_next(self, init=False):

     def get_data(self, i):
         # inputs
-        X = {k: self.table['_' + k][i].copy() for k in self._data_config.input_names}
+        X = {k: copy.deepcopy(self.table['_' + k][i]) for k in self._data_config.input_names}
         # labels
-        y = {k: self.table[k][i].copy() for k in self._data_config.label_names}
+        y = {k: copy.deepcopy(self.table[k][i]) for k in self._data_config.label_names}
         # observers / monitor variables
-        Z = {k: self.table[k][i].copy() for k in self._data_config.z_variables}
+        Z = {k: copy.deepcopy(self.table[k][i]) for k in self._data_config.z_variables}
         return X, y, Z
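
The switch from `.copy()` to `copy.deepcopy` here (and in `restart` and `_get_reweight_indices` above) presumably makes the copying uniform now that table entries can be `ak.Array` objects as well as NumPy arrays: `ak.Array` has no `.copy()` method, but it pickles, so `copy.deepcopy` works on both. A small sketch of that assumption:

```python
import copy

import awkward as ak
import numpy as np

row_np = np.array([1.0, 2.0])
row_ak = ak.Array([[1.0], [2.0, 3.0]])

safe_np = copy.deepcopy(row_np)  # same effect as row_np.copy()
safe_ak = copy.deepcopy(row_ak)  # works; row_ak.copy() would raise AttributeError
```
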


@@ -315,6 +315,10 @@ def __init__(self, file_dict, data_config_file,
             'max_resample': max_resample,
         }

+        # ==== torch collate_fn map ====
+        from torch.utils.data._utils.collate import default_collate_fn_map
+        default_collate_fn_map.update({ak.Array: _collate_awkward_array_fn})
+
         if for_training:
             self._sampler_options.update(training=True, shuffle=True, reweight=True)
         else:
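
The registration above hooks into PyTorch's type-dispatched collation (`default_collate_fn_map` is an internal API, available since roughly torch 1.13), so batch values that are `ak.Array` get stacked jagged by weaver's `_stack` instead of being converted to tensors. A rough standalone sketch, with a simplified stand-in for `_stack`:

```python
import awkward as ak
import numpy as np
from torch.utils.data import default_collate
from torch.utils.data._utils.collate import default_collate_fn_map


def _collate_awkward_array_fn(batch, *, collate_fn_map=None):
    # stand-in for weaver's _stack: stack the samples along a new outer axis
    return ak.concatenate([a[np.newaxis] for a in batch], axis=0)


default_collate_fn_map.update({ak.Array: _collate_awkward_array_fn})

samples = [
    {'x': np.float32(1.0), 'z': ak.Array([0.5, 0.7])},  # 'z' is jagged per sample
    {'x': np.float32(2.0), 'z': ak.Array([0.9])},
]
batch = default_collate(samples)  # 'x' -> torch.Tensor, 'z' -> ak.Array of lists
```
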

weaver/utils/nn/tools.py (15 additions, 10 deletions)

@@ -59,12 +59,14 @@ def train_classification(
     total_loss = 0
     num_batches = 0
     total_correct = 0
+    entry_count = 0
     count = 0
     start_time = time.time()
     with tqdm.tqdm(train_loader) as tq:
         for X, y, _ in tq:
             inputs = [X[k].to(dev) for k in data_config.input_names]
             label = y[data_config.label_names[0]].long().to(dev)
+            entry_count += label.shape[0]
             try:
                 mask = y[data_config.label_names[0] + '_mask'].bool().to(dev)
             except KeyError:
@@ -117,7 +119,7 @@ def train_classification(
             break

     time_diff = time.time() - start_time
-    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (count, count / time_diff))
+    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (entry_count, entry_count / time_diff))
     _logger.info('Train AvgLoss: %.5f, AvgAcc: %.5f' % (total_loss / num_batches, total_correct / count))
     _logger.info('Train class distribution: \n %s', str(sorted(label_counter.items())))

@@ -157,6 +159,7 @@ def evaluate_classification(model, test_loader, dev, epoch, for_training=True, l
     with torch.no_grad():
         with tqdm.tqdm(test_loader) as tq:
             for X, y, Z in tq:
+                # X, y: torch.Tensor; Z: ak.Array
                 inputs = [X[k].to(dev) for k in data_config.input_names]
                 label = y[data_config.label_names[0]].long().to(dev)
                 entry_count += label.shape[0]
@@ -174,7 +177,7 @@ def evaluate_classification(model, test_loader, dev, epoch, for_training=True, l
                     labels[k].append(_flatten_label(v, mask).numpy(force=True))
                 if not for_training:
                     for k, v in Z.items():
-                        observers[k].append(v.numpy(force=True))
+                        observers[k].append(v)

                 num_examples = label.shape[0]
                 label_counter.update(label.numpy(force=True))
@@ -206,7 +209,7 @@ def evaluate_classification(model, test_loader, dev, epoch, for_training=True, l
                     break

     time_diff = time.time() - start_time
-    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (count, count / time_diff))
+    _logger.info('Processed %d entries in total (avg. speed %.1f entries/s)' % (entry_count, entry_count / time_diff))
     _logger.info('Evaluation class distribution: \n %s', str(sorted(label_counter.items())))

     if tb_helper:
@@ -259,18 +262,19 @@ def evaluate_onnx(model_path, test_loader, eval_metrics=['roc_auc_score', 'roc_a
     start_time = time.time()
     with tqdm.tqdm(test_loader) as tq:
         for X, y, Z in tq:
-            inputs = {k: v.cpu().numpy() for k, v in X.items()}
-            label = y[data_config.label_names[0]].cpu().numpy()
+            # X, y: torch.Tensor; Z: ak.Array
+            inputs = {k: v.numpy(force=True) for k, v in X.items()}
+            label = y[data_config.label_names[0]].numpy(force=True)
             num_examples = label.shape[0]
             label_counter.update(label)
             score = sess.run([], inputs)[0]
             preds = score.argmax(1)

             scores.append(score)
             for k, v in y.items():
-                labels[k].append(v.cpu().numpy())
+                labels[k].append(v.numpy(force=True))
             for k, v in Z.items():
-                observers[k].append(v.cpu().numpy())
+                observers[k].append(v)

             correct = (preds == label).sum()
             total_correct += correct
@@ -404,19 +408,20 @@ def evaluate_regression(model, test_loader, dev, epoch, for_training=True, loss_
     with torch.no_grad():
         with tqdm.tqdm(test_loader) as tq:
             for X, y, Z in tq:
+                # X, y: torch.Tensor; Z: ak.Array
                 inputs = [X[k].to(dev) for k in data_config.input_names]
                 label = y[data_config.label_names[0]].float()
                 num_examples = label.shape[0]
                 label = label.to(dev)
                 model_output = model(*inputs)
                 preds = model_output.squeeze().float()

-                scores.append(preds.detach().cpu().numpy())
+                scores.append(preds.numpy(force=True))
                 for k, v in y.items():
-                    labels[k].append(v.cpu().numpy())
+                    labels[k].append(v.numpy(force=True))
                 if not for_training:
                     for k, v in Z.items():
-                        observers[k].append(v.cpu().numpy())
+                        observers[k].append(v)

                 loss = 0 if loss_func is None else loss_func(preds, label).item()

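
Throughout this file, the `.detach().cpu().numpy()` and `.cpu().numpy()` chains become `Tensor.numpy(force=True)`, which (since torch 1.13) detaches, moves to CPU, and resolves conjugate/negative bits in one call. A two-line illustration:

```python
import torch

t = torch.arange(4.0, requires_grad=True) * 2  # grad-tracking tensor
a = t.numpy(force=True)  # detaches and copies to CPU first; plain t.numpy() would raise
```
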
