From c705cdc7fb217c8ef2b3e6fa3f0a3979fc3dc6c3 Mon Sep 17 00:00:00 2001
From: Shojin Tam <99107633+shojint@users.noreply.github.com>
Date: Wed, 1 Jan 2025 15:34:45 +0800
Subject: [PATCH] Fix MPIIGaze training bugs in train.py

This commit fixes Ahmednull#37:

* the MPIIGaze optimizer called get_ignored_params, get_non_ignored_params
  and get_fc_params with an extra args.arch argument those helpers do not
  accept; dedicated *_mpii generators are added and used instead
* the Mpiigaze dataset was constructed with one argument too few; the call
  now passes 180 for the missing parameter
* a stray '/' in the snapshot path is removed
---
 train.py | 38 +++++++++++++++++++++++++++++++++-----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index b244d87b..dc2df2b3 100644
--- a/train.py
+++ b/train.py
@@ -93,6 +93,34 @@ def get_fc_params(model):
         for module_name, module in b[i].named_modules():
             for name, param in module.named_parameters():
                 yield param
+
+def get_ignored_params_mpii(model):
+    # Generator function that yields ignored params.
+    b = [model.module.conv1, model.module.bn1, model.module.fc_finetune]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+def get_non_ignored_params_mpii(model):
+    # Generator function that yields params that will be optimized.
+    b = [model.module.layer1, model.module.layer2, model.module.layer3, model.module.layer4]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            if 'bn' in module_name:
+                module.eval()
+            for name, param in module.named_parameters():
+                yield param
+
+def get_fc_params_mpii(model):
+    # Generator function that yields fc layer params.
+    b = [model.module.fc_yaw_gaze, model.module.fc_pitch_gaze]
+    for i in range(len(b)):
+        for module_name, module in b[i].named_modules():
+            for name, param in module.named_parameters():
+                yield param
 
 def load_filtered_state_dict(model, snapshot):
     # By user apaszke from discuss.pytorch.org
@@ -272,7 +300,7 @@ def getArch_weights(arch, bins):
         model = nn.DataParallel(model)
         model.to(gpu)
         print('Loading data.')
-        dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, True, fold)
+        dataset=Mpiigaze(testlabelpathombined,args.gazeMpiimage_dir, transformations, True, 180, fold)
         train_loader_gaze = DataLoader(
             dataset=dataset,
             batch_size=int(batch_size),
@@ -296,9 +324,9 @@ def getArch_weights(arch, bins):
 
         # Optimizer gaze
         optimizer_gaze = torch.optim.Adam([
-            {'params': get_ignored_params(model, args.arch), 'lr': 0},
-            {'params': get_non_ignored_params(model, args.arch), 'lr': args.lr},
-            {'params': get_fc_params(model, args.arch), 'lr': args.lr}
+            {'params': get_ignored_params_mpii(model), 'lr': 0},
+            {'params': get_non_ignored_params_mpii(model), 'lr': args.lr},
+            {'params': get_fc_params_mpii(model), 'lr': args.lr}
         ], args.lr)
 
 
@@ -375,7 +403,7 @@ def getArch_weights(arch, bins):
         if epoch % 1 == 0 and epoch < num_epochs:
             print('Taking snapshot...',
                   torch.save(model.state_dict(),
-                             output+'/fold' + str(fold) +'/'+
+                             output+'fold' + str(fold) +'/'+
                              '_epoch_' + str(epoch+1) + '.pkl')
                   )
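
Usage sketch (not part of the patch itself): the snippet below shows how the new *_mpii generators plug into torch.optim.Adam, with the stem (conv1, bn1, fc_finetune) kept in the optimizer at learning rate 0 and the backbone blocks plus the yaw/pitch heads trained at the base rate. The three generators are taken from the patch above; DummyL2CS is a hypothetical stand-in that only exposes the attribute names the generators expect, and lr is a placeholder for args.lr.

    # Minimal, self-contained sketch of the per-group learning-rate setup.
    from collections import OrderedDict

    import torch
    import torch.nn as nn

    def get_ignored_params_mpii(model):
        # Generator function that yields ignored params.
        b = [model.module.conv1, model.module.bn1, model.module.fc_finetune]
        for i in range(len(b)):
            for module_name, module in b[i].named_modules():
                if 'bn' in module_name:
                    module.eval()
                for name, param in module.named_parameters():
                    yield param

    def get_non_ignored_params_mpii(model):
        # Generator function that yields params that will be optimized.
        b = [model.module.layer1, model.module.layer2,
             model.module.layer3, model.module.layer4]
        for i in range(len(b)):
            for module_name, module in b[i].named_modules():
                if 'bn' in module_name:
                    module.eval()
                for name, param in module.named_parameters():
                    yield param

    def get_fc_params_mpii(model):
        # Generator function that yields fc layer params.
        b = [model.module.fc_yaw_gaze, model.module.fc_pitch_gaze]
        for i in range(len(b)):
            for module_name, module in b[i].named_modules():
                for name, param in module.named_parameters():
                    yield param

    def block(cin, cout):
        # Named children so the 'bn' substring check above can match.
        return nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(cin, cout, 3, padding=1)),
            ('bn', nn.BatchNorm2d(cout)),
        ]))

    class DummyL2CS(nn.Module):
        # Hypothetical stand-in: only the attribute names mirror the real model.
        def __init__(self, bins=28):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.layer1, self.layer2 = block(64, 64), block(64, 128)
            self.layer3, self.layer4 = block(128, 256), block(256, 512)
            self.fc_finetune = nn.Linear(512, 512)
            self.fc_yaw_gaze = nn.Linear(512, bins)
            self.fc_pitch_gaze = nn.Linear(512, bins)

    model = nn.DataParallel(DummyL2CS())
    lr = 1e-5  # placeholder for args.lr

    optimizer_gaze = torch.optim.Adam([
        {'params': get_ignored_params_mpii(model), 'lr': 0},        # frozen stem
        {'params': get_non_ignored_params_mpii(model), 'lr': lr},   # backbone
        {'params': get_fc_params_mpii(model), 'lr': lr},            # gaze heads
    ], lr)

Keeping the frozen stem in the optimizer at lr 0, rather than dropping it, mirrors the existing get_ignored_params helper in train.py. Note that the module.eval() calls only execute when Adam iterates the generators at construction time, so a later model.train() will put those batch-norm layers back into training mode.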