diff --git a/setup.py b/setup.py index b54a95b..9c1271f 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ install_requires = [l.strip() for l in f.readlines()] setup(name='truenet', - version='1.0.1', + version='1.0.2', description='DL method for WMH segmentation', author='Vaanathi Sundaresan', install_requires=install_requires, diff --git a/truenet/scripts/truenet b/truenet/scripts/truenet index e594f0e..94bfd68 100644 --- a/truenet/scripts/truenet +++ b/truenet/scripts/truenet @@ -109,6 +109,8 @@ if __name__ == "__main__": optionalNamedft.add_argument('-cp_type', '--cp_save_type', type = str, required=False, default='last', help='Checkpoint saving options: best, last, everyN (default=last)') optionalNamedft.add_argument('-cp_n', '--cp_everyn_N', type = int, required=False, default=10, help='If -cp_type=everyN, the N value') optionalNamedft.add_argument('-v', '--verbose', type = bool, required=False, default=False, help='Display debug messages (default=False)') + optionalNamedft.add_argument('-cpu', '--use_cpu', type=bool, required=False, default=False, + help='Perform model fine-tuning on CPU True/False (default=False)') parser_finetune.set_defaults(func=truenet_commands.fine_tune) diff --git a/truenet/true_net/truenet_commands.py b/truenet/true_net/truenet_commands.py index 317d0a3..27fe7f0 100644 --- a/truenet/true_net/truenet_commands.py +++ b/truenet/true_net/truenet_commands.py @@ -683,7 +683,8 @@ def fine_tune(args): 'Pretrained': pretrained, 'Modelname': model_name, 'SaveResume': args.save_resume_training, - 'Numchannels': num_channels + 'Numchannels': num_channels, + 'Use_CPU': args.use_cpu, } if args.cp_save_type not in ['best', 'last', 'everyN']: @@ -935,4 +936,3 @@ def cross_validate(args): intermediate=args.intermediate, save_cp=args.save_checkpoint, save_wei=save_wei, save_case=args.cp_save_type, verbose=args.verbose, dir_cp=out_dir, output_dir=out_dir) - diff --git a/truenet/true_net/truenet_cross_validate.py b/truenet/true_net/truenet_cross_validate.py index 27a1c01..61c6b19 100644 --- a/truenet/true_net/truenet_cross_validate.py +++ b/truenet/true_net/truenet_cross_validate.py @@ -95,10 +95,10 @@ def main(sub_name_dicts, cv_params, aug=True, weighted=True, intermediate=False, if fld == (fold - 1): test_ids = np.arange(fld * test_subs_per_fold, len(sub_name_dicts)) - test_sub_dicts = sub_name_dicts[test_ids] + test_sub_dicts = [sub_name_dicts[i] for i in test_ids] else: test_ids = np.arange(fld * test_subs_per_fold, (fld+1) * test_subs_per_fold) - test_sub_dicts = sub_name_dicts[test_ids] + test_sub_dicts = [sub_name_dicts[i] for i in test_ids] rem_sub_ids = np.setdiff1d(np.arange(len(sub_name_dicts)),test_ids) rem_sub_name_dicts = [sub_name_dicts[idx] for idx in rem_sub_ids] diff --git a/truenet/true_net/truenet_finetune.py b/truenet/true_net/truenet_finetune.py index 3f44a47..adb9cca 100644 --- a/truenet/true_net/truenet_finetune.py +++ b/truenet/true_net/truenet_finetune.py @@ -39,11 +39,11 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ device = torch.device("cpu") nclass = ft_params['Nclass'] - numchannels = ft_params['Num_channels'] + numchannels = ft_params['Numchannels'] pretrained = ft_params['Pretrained'] model_name = ft_params['Modelname'] - if pretrained == 1: + if pretrained: nclass = 2 model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial') model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal') @@ -56,13 +56,13 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ model_sagittal = nn.DataParallel(model_sagittal) model_coronal = nn.DataParallel(model_coronal) model_path = os.path.join(model_dir, model_name + '_axial.pth') - model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model') + model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model') model_path = os.path.join(model_dir, model_name + '_sagittal.pth') - model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model') + model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model') model_path = os.path.join(model_dir, model_name + '_coronal.pth') - model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model') + model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model') else: try: model_path = os.path.join(model_dir, model_name + '_axial.pth') @@ -75,19 +75,19 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial') model_axial.to(device=device) model_axial = nn.DataParallel(model_axial) - model_axial = truenet_utils.loading_model(model_path, model_axial) + model_axial = truenet_utils.loading_model(model_path, model_axial, device) model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal') model_sagittal.to(device=device) model_sagittal = nn.DataParallel(model_sagittal) model_path = os.path.join(model_dir, model_name + '_sagittal.pth') - model_sagittal = truenet_utils.loading_model(model_path, model_sagittal) + model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device) model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal') model_coronal.to(device=device) model_coronal = nn.DataParallel(model_coronal) model_path = os.path.join(model_dir, model_name + '_coronal.pth') - model_coronal = truenet_utils.loading_model(model_path, model_coronal) + model_coronal = truenet_utils.loading_model(model_path, model_coronal, device) except: try: model_path = os.path.join(model_dir, model_name + '_axial.pth') @@ -100,20 +100,20 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial') model_axial.to(device=device) model_axial = nn.DataParallel(model_axial) - model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model') + model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model') model_path = os.path.join(model_dir, model_name + '_sagittal.pth') model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal') model_sagittal.to(device=device) model_sagittal = nn.DataParallel(model_sagittal) - model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model') + model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model') model_path = os.path.join(model_dir, model_name + '_coronal.pth') model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal') model_coronal.to(device=device) model_coronal = nn.DataParallel(model_coronal) - model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model') + model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model') except ImportError: raise ImportError('In directory ' + model_dir + ', ' + model_name + '_axial.pth or' + model_name + '_sagittal.pth or' + model_name + '_coronal.pth ' + @@ -153,14 +153,14 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ print('Axial model: ', str(sum([p.numel() for p in model_axial.parameters()])), flush=True) print('Sagittal model: ', str(sum([p.numel() for p in model_sagittal.parameters()])), flush=True) print('Coronal model: ', str(sum([p.numel() for p in model_coronal.parameters()])), flush=True) - + model_axial = truenet_utils.freeze_layer_for_finetuning(model_axial, layers_to_ft, verbose=verbose) model_sagittal = truenet_utils.freeze_layer_for_finetuning(model_sagittal, layers_to_ft, verbose=verbose) model_coronal = truenet_utils.freeze_layer_for_finetuning(model_coronal, layers_to_ft, verbose=verbose) model_axial.to(device=device) model_sagittal.to(device=device) model_coronal.to(device=device) - + print('Total number of trainable parameters', flush=True) model_parameters = filter(lambda p: p.requires_grad, model_axial.parameters()) params = sum([p.numel() for p in model_parameters]) @@ -171,7 +171,7 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ model_parameters = filter(lambda p: p.requires_grad, model_coronal.parameters()) params = sum([p.numel() for p in model_parameters]) print('Coronal model: ', str(params), flush=True) - + if optim_type == 'adam': epsilon = ft_params['Epsilon'] optimizer_axial = optim.Adam(filter(lambda p: p.requires_grad, @@ -190,12 +190,12 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ model_coronal.parameters()), lr=ft_lrt, momentum=moment) else: raise ValueError("Invalid optimiser choice provided! Valid options: 'adam', 'sgd'") - + if nclass == 2: criterion = truenet_loss_functions.CombinedLoss() else: criterion = truenet_loss_functions.CombinedMultiLoss(nclasses=nclass) - + if verbose: print('Found' + str(len(sub_name_dicts)) + 'subjects', flush=True) @@ -210,7 +210,7 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ augment=aug, weighted=weighted, save_checkpoint=save_cp, save_weights=save_wei, save_case=save_case, verbose=verbose, dir_checkpoint=dir_cp) - + if req_plane == 'all' or req_plane == 'sagittal': scheduler = optim.lr_scheduler.MultiStepLR(optimizer_sagittal, milestones, gamma=gamma, last_epoch=-1) model_sagittal = truenet_train.train_truenet(train_name_dicts, val_name_dicts, model_sagittal, criterion, @@ -218,7 +218,7 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ augment=aug, weighted=weighted, save_checkpoint=save_cp, save_weights=save_wei, save_case=save_case, verbose=verbose, dir_checkpoint=dir_cp) - + if req_plane == 'all' or req_plane == 'coronal': scheduler = optim.lr_scheduler.MultiStepLR(optimizer_coronal, milestones, gamma=gamma, last_epoch=-1) model_coronal = truenet_train.train_truenet(train_name_dicts, val_name_dicts, model_coronal, criterion, @@ -228,6 +228,3 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_ dir_checkpoint=dir_cp) print('Model Fine-tuning done!', flush=True) - - -