Skip to content

Commit

Permalink
Bug: Missing options and pass-through to fine tune routine
Browse files Browse the repository at this point in the history
  • Loading branch information
pauldmccarthy committed Jul 26, 2024
1 parent a420439 commit 5e67cc4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
2 changes: 2 additions & 0 deletions truenet/scripts/truenet
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ if __name__ == "__main__":
optionalNamedft.add_argument('-cp_type', '--cp_save_type', type = str, required=False, default='last', help='Checkpoint saving options: best, last, everyN (default=last)')
optionalNamedft.add_argument('-cp_n', '--cp_everyn_N', type = int, required=False, default=10, help='If -cp_type=everyN, the N value')
optionalNamedft.add_argument('-v', '--verbose', type = bool, required=False, default=False, help='Display debug messages (default=False)')
optionalNamedft.add_argument('-cpu', '--use_cpu', type=bool, required=False, default=False,
help='Perform model fine-tuning on CPU True/False (default=False)')
parser_finetune.set_defaults(func=truenet_commands.fine_tune)


Expand Down
4 changes: 2 additions & 2 deletions truenet/true_net/truenet_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ def fine_tune(args):
'Pretrained': pretrained,
'Modelname': model_name,
'SaveResume': args.save_resume_training,
'Numchannels': num_channels
'Numchannels': num_channels,
'Use_CPU': args.use_cpu,
}

if args.cp_save_type not in ['best', 'last', 'everyN']:
Expand Down Expand Up @@ -935,4 +936,3 @@ def cross_validate(args):
intermediate=args.intermediate, save_cp=args.save_checkpoint, save_wei=save_wei,
save_case=args.cp_save_type, verbose=args.verbose, dir_cp=out_dir,
output_dir=out_dir)

20 changes: 10 additions & 10 deletions truenet/true_net/truenet_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
device = torch.device("cpu")

nclass = ft_params['Nclass']
numchannels = ft_params['Num_channels']
numchannels = ft_params['Numchannels']
pretrained = ft_params['Pretrained']
model_name = ft_params['Modelname']

Expand All @@ -56,13 +56,13 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_sagittal = nn.DataParallel(model_sagittal)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
else:
try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
Expand All @@ -75,19 +75,19 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial, device)

model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device)

model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal)
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device)
except:
try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
Expand All @@ -100,20 +100,20 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64,
plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
except ImportError:
raise ImportError('In directory ' + model_dir + ', ' + model_name + '_axial.pth or' +
model_name + '_sagittal.pth or' + model_name + '_coronal.pth ' +
Expand Down

0 comments on commit 5e67cc4

Please sign in to comment.