Skip to content

Commit

Permalink
Merge pull request #13 from pauldmccarthy/bf/cross-validate
Browse files Browse the repository at this point in the history
Bug: Fix indexing in cross-validation preparation
  • Loading branch information
v-sundaresan authored Jul 27, 2024
2 parents 8229deb + 5e67cc4 commit a324783
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 26 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
install_requires = [l.strip() for l in f.readlines()]

setup(name='truenet',
version='1.0.1',
version='1.0.2',
description='DL method for WMH segmentation',
author='Vaanathi Sundaresan',
install_requires=install_requires,
Expand Down
2 changes: 2 additions & 0 deletions truenet/scripts/truenet
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ if __name__ == "__main__":
optionalNamedft.add_argument('-cp_type', '--cp_save_type', type = str, required=False, default='last', help='Checkpoint saving options: best, last, everyN (default=last)')
optionalNamedft.add_argument('-cp_n', '--cp_everyn_N', type = int, required=False, default=10, help='If -cp_type=everyN, the N value')
optionalNamedft.add_argument('-v', '--verbose', type = bool, required=False, default=False, help='Display debug messages (default=False)')
optionalNamedft.add_argument('-cpu', '--use_cpu', type=bool, required=False, default=False,
help='Perform model fine-tuning on CPU True/False (default=False)')
parser_finetune.set_defaults(func=truenet_commands.fine_tune)


Expand Down
4 changes: 2 additions & 2 deletions truenet/true_net/truenet_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ def fine_tune(args):
'Pretrained': pretrained,
'Modelname': model_name,
'SaveResume': args.save_resume_training,
'Numchannels': num_channels
'Numchannels': num_channels,
'Use_CPU': args.use_cpu,
}

if args.cp_save_type not in ['best', 'last', 'everyN']:
Expand Down Expand Up @@ -935,4 +936,3 @@ def cross_validate(args):
intermediate=args.intermediate, save_cp=args.save_checkpoint, save_wei=save_wei,
save_case=args.cp_save_type, verbose=args.verbose, dir_cp=out_dir,
output_dir=out_dir)

4 changes: 2 additions & 2 deletions truenet/true_net/truenet_cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ def main(sub_name_dicts, cv_params, aug=True, weighted=True, intermediate=False,

if fld == (fold - 1):
test_ids = np.arange(fld * test_subs_per_fold, len(sub_name_dicts))
test_sub_dicts = sub_name_dicts[test_ids]
test_sub_dicts = [sub_name_dicts[i] for i in test_ids]
else:
test_ids = np.arange(fld * test_subs_per_fold, (fld+1) * test_subs_per_fold)
test_sub_dicts = sub_name_dicts[test_ids]
test_sub_dicts = [sub_name_dicts[i] for i in test_ids]

rem_sub_ids = np.setdiff1d(np.arange(len(sub_name_dicts)),test_ids)
rem_sub_name_dicts = [sub_name_dicts[idx] for idx in rem_sub_ids]
Expand Down
39 changes: 18 additions & 21 deletions truenet/true_net/truenet_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
device = torch.device("cpu")

nclass = ft_params['Nclass']
numchannels = ft_params['Num_channels']
numchannels = ft_params['Numchannels']
pretrained = ft_params['Pretrained']
model_name = ft_params['Modelname']

if pretrained == 1:
if pretrained:
nclass = 2
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal')
Expand All @@ -56,13 +56,13 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_sagittal = nn.DataParallel(model_sagittal)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
else:
try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
Expand All @@ -75,19 +75,19 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial, device)

model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device)

model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal)
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device)
except:
try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
Expand All @@ -100,20 +100,20 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64,
plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
except ImportError:
raise ImportError('In directory ' + model_dir + ', ' + model_name + '_axial.pth or' +
model_name + '_sagittal.pth or' + model_name + '_coronal.pth ' +
Expand Down Expand Up @@ -153,14 +153,14 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
print('Axial model: ', str(sum([p.numel() for p in model_axial.parameters()])), flush=True)
print('Sagittal model: ', str(sum([p.numel() for p in model_sagittal.parameters()])), flush=True)
print('Coronal model: ', str(sum([p.numel() for p in model_coronal.parameters()])), flush=True)

model_axial = truenet_utils.freeze_layer_for_finetuning(model_axial, layers_to_ft, verbose=verbose)
model_sagittal = truenet_utils.freeze_layer_for_finetuning(model_sagittal, layers_to_ft, verbose=verbose)
model_coronal = truenet_utils.freeze_layer_for_finetuning(model_coronal, layers_to_ft, verbose=verbose)
model_axial.to(device=device)
model_sagittal.to(device=device)
model_coronal.to(device=device)

print('Total number of trainable parameters', flush=True)
model_parameters = filter(lambda p: p.requires_grad, model_axial.parameters())
params = sum([p.numel() for p in model_parameters])
Expand All @@ -171,7 +171,7 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_parameters = filter(lambda p: p.requires_grad, model_coronal.parameters())
params = sum([p.numel() for p in model_parameters])
print('Coronal model: ', str(params), flush=True)

if optim_type == 'adam':
epsilon = ft_params['Epsilon']
optimizer_axial = optim.Adam(filter(lambda p: p.requires_grad,
Expand All @@ -190,12 +190,12 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_coronal.parameters()), lr=ft_lrt, momentum=moment)
else:
raise ValueError("Invalid optimiser choice provided! Valid options: 'adam', 'sgd'")

if nclass == 2:
criterion = truenet_loss_functions.CombinedLoss()
else:
criterion = truenet_loss_functions.CombinedMultiLoss(nclasses=nclass)

if verbose:
print('Found' + str(len(sub_name_dicts)) + 'subjects', flush=True)

Expand All @@ -210,15 +210,15 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
augment=aug, weighted=weighted, save_checkpoint=save_cp,
save_weights=save_wei, save_case=save_case, verbose=verbose,
dir_checkpoint=dir_cp)

if req_plane == 'all' or req_plane == 'sagittal':
scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sagittal, milestones, gamma=gamma, last_epoch=-1)
model_sagittal = truenet_train.train_truenet(train_name_dicts, val_name_dicts, model_sagittal, criterion,
optimizer_sagittal, scheduler, ft_params, device, mode='sagittal',
augment=aug, weighted=weighted, save_checkpoint=save_cp,
save_weights=save_wei, save_case=save_case, verbose=verbose,
dir_checkpoint=dir_cp)

if req_plane == 'all' or req_plane == 'coronal':
scheduler = optim.lr_scheduler.MultiStepLR(optimizer_coronal, milestones, gamma=gamma, last_epoch=-1)
model_coronal = truenet_train.train_truenet(train_name_dicts, val_name_dicts, model_coronal, criterion,
Expand All @@ -228,6 +228,3 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
dir_checkpoint=dir_cp)

print('Model Fine-tuning done!', flush=True)



0 comments on commit a324783

Please sign in to comment.