diff --git a/setup.py b/setup.py index eeb80001..bf1bfa78 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ install_requires.append(line) setup(name="weaver-core", - version='0.4.5', + version='0.4.6', description="A streamlined deep-learning framework for high energy physics", long_description_content_type="text/markdown", author="H. Qu, C. Li", diff --git a/weaver/train.py b/weaver/train.py index 1514f0b4..fad4a185 100644 --- a/weaver/train.py +++ b/weaver/train.py @@ -9,6 +9,7 @@ import functools import numpy as np import math +import copy import torch from torch.utils.data import DataLoader @@ -824,6 +825,11 @@ def _main(args): else: gpus = None dev = torch.device('cpu') + try: + if torch.backends.mps.is_available(): + dev = torch.device('mps') + except AttributeError: + pass model = orig_model.to(dev) model_path = args.model_prefix if args.model_prefix.endswith( '.pt') else args.model_prefix + '_best_epoch_state.pt' @@ -913,20 +919,23 @@ def main(): if args.cross_validation: model_dir, model_fn = os.path.split(args.model_prefix) - predict_output_base, predict_output_ext = os.path.splitext(args.predict_output) + if args.predict_output: + predict_output_base, predict_output_ext = os.path.splitext(args.predict_output) load_model = args.load_model_weights or None var_name, kfold = args.cross_validation.split('%') kfold = int(kfold) for i in range(kfold): _logger.info(f'\n=== Running cross validation, fold {i} of {kfold} ===') - args.model_prefix = os.path.join(f'{model_dir}_fold{i}', model_fn) - args.predict_output = f'{predict_output_base}_fold{i}' + predict_output_ext - args.extra_selection = f'{var_name}%{kfold}!={i}' - args.extra_test_selection = f'{var_name}%{kfold}=={i}' + opts = copy.deepcopy(args) + opts.model_prefix = os.path.join(f'{model_dir}_fold{i}', model_fn) + if args.predict_output: + opts.predict_output = f'{predict_output_base}_fold{i}' + predict_output_ext + opts.extra_selection = f'{var_name}%{kfold}!={i}' + opts.extra_test_selection = f'{var_name}%{kfold}=={i}' if load_model and '{fold}' in load_model: - args.load_model_weights = load_model.replace('{fold}', f'fold{i}') + opts.load_model_weights = load_model.replace('{fold}', f'fold{i}') - _main(args) + _main(opts) else: _main(args)