Skip to content

Commit

Permalink
Update truenet_finetune.py
Browse files Browse the repository at this point in the history
  • Loading branch information
v-sundaresan authored Jul 25, 2024
1 parent 1054274 commit 8229deb
Showing 1 changed file with 81 additions and 61 deletions.
142 changes: 81 additions & 61 deletions truenet/true_net/truenet_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# 10-03-2021, Oxford
#=========================================================================================


def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_wei=True, save_case='best',
verbose=True, model_dir=None, dir_cp=None):
'''
Expand All @@ -34,87 +35,106 @@ def main(sub_name_dicts, ft_params, aug=True, weighted=True, save_cp=True, save_
assert len(sub_name_dicts) >= 5, "Number of distinct subjects for fine-tuning cannot be less than 5"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nclass = ft_params['Nclass']
num_channels = ft_params['Numchannels']

model_axial = truenet_model.TrUENet(n_channels=num_channels, n_classes=nclass, init_channels=64, plane='axial')
model_sagittal = truenet_model.TrUENet(n_channels=num_channels, n_classes=nclass, init_channels=64,
plane='sagittal')
model_coronal = truenet_model.TrUENet(n_channels=num_channels, n_classes=nclass, init_channels=64, plane='coronal')


model_axial.to(device=device)
model_sagittal.to(device=device)
model_coronal.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_sagittal = nn.DataParallel(model_sagittal)
model_coronal = nn.DataParallel(model_coronal)
if ft_params['Use_CPU']:
device = torch.device("cpu")

load_case = ft_params['Load_type']
nclass = ft_params['Nclass']
numchannels = ft_params['Num_channels']
pretrained = ft_params['Pretrained']

if pretrained:
if load_case == 'last':
model_path = os.path.join(model_dir, 'Truenet_model_beforeES_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_beforeES_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_beforeES_coronalal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
elif load_case == 'best':
model_path = os.path.join(model_dir, 'Truenet_model_bestdice_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_bestdice_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_bestdice_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
elif load_case == 'everyN':
cpn = ft_params['EveryN']
try:
model_path = os.path.join(model_dir, 'Truenet_model_epoch' + str(cpn) + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_epoch' + str(cpn) + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')

model_path = os.path.join(model_dir, 'Truenet_model_epoch' + str(cpn) + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
except ImportError:
raise ImportError(
'Incorrect N value provided for the available pretrained models')
else:
raise ValueError("Invalid saving condition provided! Valid options: best, specific, last")
model_name = ft_params['Modelname']

if pretrained == 1:
nclass = 2
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal')
model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')

model_axial.to(device=device)
model_sagittal.to(device=device)
model_coronal.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_sagittal = nn.DataParallel(model_sagittal)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
else:
model_name = ft_params['Modelname']

try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, device)

state_dict = torch.load(model_path)
for key, value in state_dict.items():
if 'outconv' in key and 'weight' in key:
nclass = state_dict[key].size()[0]
if 'inpconv' in key and 'weight' in key:
numchannels = value.size()[1]
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial)

model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal)

model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device)
model_coronal = truenet_utils.loading_model(model_path, model_coronal)
except:
try:
model_path = os.path.join(model_dir, model_name + '_axial.pth')
model_axial = truenet_utils.loading_model(model_path, model_axial, device, mode='full_model')
state_dict = torch.load(model_path)
for key, value in state_dict.items():
if 'outconv' in key and 'weight' in key:
nclass = state_dict[key].size()[0]
if 'inpconv' in key and 'weight' in key:
numchannels = value.size()[1]
model_axial = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='axial')
model_axial.to(device=device)
model_axial = nn.DataParallel(model_axial)
model_axial = truenet_utils.loading_model(model_path, model_axial, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_sagittal.pth')
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, device, mode='full_model')
model_sagittal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64,
plane='sagittal')
model_sagittal.to(device=device)
model_sagittal = nn.DataParallel(model_sagittal)
model_sagittal = truenet_utils.loading_model(model_path, model_sagittal, mode='full_model')

model_path = os.path.join(model_dir, model_name + '_coronal.pth')
model_coronal = truenet_utils.loading_model(model_path, model_coronal, device, mode='full_model')
model_coronal = truenet_model.TrUENet(n_channels=numchannels, n_classes=nclass, init_channels=64, plane='coronal')
model_coronal.to(device=device)
model_coronal = nn.DataParallel(model_coronal)
model_coronal = truenet_utils.loading_model(model_path, model_coronal, mode='full_model')
except ImportError:
raise ImportError('In directory ' + model_dir + ', ' + model_name + '_axial.pth or' +
model_name + '_sagittal.pth or' + model_name + '_coronal.pth ' +
'does not appear to be a valid model file')

if sub_name_dicts[0]['flair_path'] is None and sub_name_dicts[0]['t1_path'] is None:
raise ImportError('At least FLAIR or T1 must be provided in the masterfile')

if numchannels > 1:
if sub_name_dicts[0]['flair_path'] is None:
raise ImportError('The pretrained model requires 2 channels but FLAIR path not found in masterfile')
elif sub_name_dicts[0]['t1_path'] is None:
raise ImportError('The pretrained model requires 2 channels but T1 path not found in masterfile')

if numchannels == 1:
if sub_name_dicts[0]['flair_path'] is not None and sub_name_dicts[0]['t1_path'] is not None:
raise ImportError(
'Pretrained model requires only 1 channel but FLAIR and T1 are provided in the masterfile')

ft_params['Num_channels'] = numchannels

layers_to_ft = ft_params['Finetuning_layers'] # list of numbers [1,8]
optim_type = ft_params['Optimizer'] # adam, sgd
milestones = ft_params['LR_Milestones'] # list of integers [1, N]
Expand Down

0 comments on commit 8229deb

Please sign in to comment.