Skip to content

Commit

Permalink
Fix cross validation.
Browse files Browse the repository at this point in the history
  • Loading branch information
hqucms committed Aug 16, 2023
1 parent 7f89e71 commit 077c6ab
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
install_requires.append(line)

setup(name="weaver-core",
version='0.4.5',
version='0.4.6',
description="A streamlined deep-learning framework for high energy physics",
long_description_content_type="text/markdown",
author="H. Qu, C. Li",
Expand Down
23 changes: 16 additions & 7 deletions weaver/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import functools
import numpy as np
import math
import copy
import torch

from torch.utils.data import DataLoader
Expand Down Expand Up @@ -824,6 +825,11 @@ def _main(args):
else:
gpus = None
dev = torch.device('cpu')
try:
if torch.backends.mps.is_available():
dev = torch.device('mps')
except AttributeError:
pass
model = orig_model.to(dev)
model_path = args.model_prefix if args.model_prefix.endswith(
'.pt') else args.model_prefix + '_best_epoch_state.pt'
Expand Down Expand Up @@ -913,20 +919,23 @@ def main():

if args.cross_validation:
model_dir, model_fn = os.path.split(args.model_prefix)
predict_output_base, predict_output_ext = os.path.splitext(args.predict_output)
if args.predict_output:
predict_output_base, predict_output_ext = os.path.splitext(args.predict_output)
load_model = args.load_model_weights or None
var_name, kfold = args.cross_validation.split('%')
kfold = int(kfold)
for i in range(kfold):
_logger.info(f'\n=== Running cross validation, fold {i} of {kfold} ===')
args.model_prefix = os.path.join(f'{model_dir}_fold{i}', model_fn)
args.predict_output = f'{predict_output_base}_fold{i}' + predict_output_ext
args.extra_selection = f'{var_name}%{kfold}!={i}'
args.extra_test_selection = f'{var_name}%{kfold}=={i}'
opts = copy.deepcopy(args)
opts.model_prefix = os.path.join(f'{model_dir}_fold{i}', model_fn)
if args.predict_output:
opts.predict_output = f'{predict_output_base}_fold{i}' + predict_output_ext
opts.extra_selection = f'{var_name}%{kfold}!={i}'
opts.extra_test_selection = f'{var_name}%{kfold}=={i}'
if load_model and '{fold}' in load_model:
args.load_model_weights = load_model.replace('{fold}', f'fold{i}')
opts.load_model_weights = load_model.replace('{fold}', f'fold{i}')

_main(args)
_main(opts)
else:
_main(args)

Expand Down

0 comments on commit 077c6ab

Please sign in to comment.