forked from Ahmednull/L2CS-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
clear_training_utils.py
144 lines (128 loc) · 6.18 KB
/
clear_training_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import torchvision
from model import L2CS
def parse_args(argv=None):
    """Parse the command-line options for L2CS-Net gaze-estimation training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before;
            passing an explicit list lets tests and notebooks build the
            namespace without touching ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option defined below
        (dashes in option names become underscores, e.g. ``--cml-tags`` is
        available as ``cml_tags``).

    Raises:
        SystemExit: raised by argparse when ``--tb`` or ``--cml-tags`` is
            missing, or when a value cannot be converted to its ``type``.
    """
    parser = argparse.ArgumentParser(description='Gaze estimation using L2CSNet.')

    # mpiigaze
    parser.add_argument(
        '--gazeMpiimage_dir', dest='gazeMpiimage_dir', help='Directory path for gaze images.',
        default='datasets/MPIIFaceGaze/Image', type=str)
    parser.add_argument(
        '--gazeMpiilabel_dir', dest='gazeMpiilabel_dir', help='Directory path for gaze labels.',
        default='datasets/MPIIFaceGaze/Label', type=str)

    # GazeCapture
    parser.add_argument('--gazecapture-dir', help='Root path of dataset.', type=str)
    parser.add_argument('--gazecapture-ann', help='Annotations filepath.', type=str)

    # Validation Dataset
    parser.add_argument('--validation-dir', help='Root path of the dataset.', type=str)
    parser.add_argument('--validation-ann', help='Annotations filepath.', type=str)

    # Important args -----------------------------------------------------------------
    parser.add_argument(
        '--dataset', dest='dataset', help='mpiigaze, rtgene, gaze360, ethgaze, gazecapture',
        default='gaze360', type=str)
    parser.add_argument(
        '--output', dest='output', help='Path of output models.',
        default='output/snapshots/', type=str)
    parser.add_argument(
        '--snapshot', dest='snapshot', help='Path of model snapshot.',
        default='', type=str)
    parser.add_argument(
        '--gpu', dest='gpu_id', help='GPU device id to use [0] or multiple 0,1,2,3',
        default='0', type=str)
    parser.add_argument(
        '--num_epochs', dest='num_epochs', help='Maximum number of training epochs.',
        default=60, type=int)
    parser.add_argument(
        '--batch_size', dest='batch_size', help='Batch size.',
        default=1, type=int)
    # Only the ResNet variants are actually built by getArch_weights in this
    # module; the help previously advertised Squeezenet/MobileNetV2 too.
    parser.add_argument(
        '--arch', dest='arch',
        help='Network architecture, can be: ResNet18, ResNet34, [ResNet50], '
             'ResNet101, ResNet152 (unrecognised values fall back to ResNet50)',
        default='ResNet50', type=str)
    # argparse applies ``type`` only to *string* defaults, so an int default
    # would stay an int; a float literal keeps args.alpha a float either way.
    parser.add_argument(
        '--alpha', dest='alpha', help='Regression loss coefficient.',
        default=1.0, type=float)
    parser.add_argument(
        '--lr', dest='lr', help='Base learning rate.',
        default=0.00001, type=float)
    # Important args -----------------------------------------------------------------

    # Tensorboard / ClearML experiment-tracking args
    parser.add_argument(
        '--tb', help='name of the output folder that will be containing data about an experiment.', required=True)
    parser.add_argument(
        '--cml-tags', nargs='*', help='tags that will be used by ClearML', required=True)

    # Loss weighting and angle-binning args
    parser.add_argument(
        '--beta', type=float, help='weight of pitch loss', default=1.0)
    parser.add_argument(
        '--gamma', type=float, help='weight of yaw loss', default=1.0)
    parser.add_argument(
        '--pitch-angle-range', type=int, help='pitch +/- angle range in degrees', default=42)
    parser.add_argument(
        '--yaw-angle-range', type=int, help='yaw +/- angle range in degrees', default=42)
    parser.add_argument(
        '--pitch-bin-count', type=int, help='bin count for pitch angle', default=28)
    parser.add_argument(
        '--yaw-bin-count', type=int, help='bin count for yaw angle', default=28)

    # argv=None -> argparse falls back to sys.argv[1:].
    return parser.parse_args(argv)
def get_ignored_params(model):
    """Yield the parameters of ``conv1``, ``bn1`` and ``fc_finetune``.

    Generator intended for building an optimizer parameter group. Each
    parameter owned by those three modules (or by any of their submodules)
    is yielded exactly once.

    Side effect: while the generator is consumed, every submodule whose
    qualified name (as reported by ``named_modules()``) contains ``'bn'`` is
    switched to eval mode. NOTE(review): each root module is reported with
    the empty name ``''``, so ``bn1`` itself is never matched by this check —
    confirm whether that is intended.

    Args:
        model: the network, presumably wrapped in ``nn.DataParallel`` (it is
            accessed through ``.module``). A model without a ``.module``
            attribute is used directly.

    Yields:
        torch parameters, each at most once per owning module.
    """
    # DataParallel-style wrappers expose the real network as ``.module``.
    net = getattr(model, 'module', model)
    for root in (net.conv1, net.bn1, net.fc_finetune):
        for module_name, module in root.named_modules():
            if 'bn' in module_name:
                module.eval()
            # named_modules() already visits every submodule, so only take the
            # parameters registered directly on this one; the recursive
            # default would re-yield nested parameters once per ancestor,
            # putting duplicates into the optimizer's parameter group.
            yield from module.parameters(recurse=False)
def get_non_ignored_params(model):
    """Yield the parameters of ``layer1`` .. ``layer4`` for optimization.

    Generator intended for building an optimizer parameter group. Every
    parameter owned by the four layers (including all nested submodules) is
    yielded exactly once, in ``named_modules()`` order.

    Side effect: while the generator is consumed, every submodule whose
    qualified name contains ``'bn'`` is switched to eval mode.

    Args:
        model: the network, presumably wrapped in ``nn.DataParallel`` (it is
            accessed through ``.module``). A model without a ``.module``
            attribute is used directly.

    Yields:
        torch parameters, each at most once per owning module.
    """
    # DataParallel-style wrappers expose the real network as ``.module``.
    net = getattr(model, 'module', model)
    for root in (net.layer1, net.layer2, net.layer3, net.layer4):
        for module_name, module in root.named_modules():
            if 'bn' in module_name:
                module.eval()
            # These layers are nested containers: with the recursive default,
            # a conv weight inside block '0' was yielded by the layer, by the
            # block, and by the conv itself — three copies in the optimizer,
            # so it was stepped three times per optimizer.step().
            yield from module.parameters(recurse=False)
def get_fc_params(model):
    """Yield the parameters of the ``fc_yaw_gaze`` and ``fc_pitch_gaze`` heads.

    Generator intended for building an optimizer parameter group. Each
    parameter owned by the two heads (including nested submodules) is
    yielded exactly once; unlike the other parameter generators here, no
    module's train/eval mode is touched.

    Args:
        model: the network, presumably wrapped in ``nn.DataParallel`` (it is
            accessed through ``.module``). A model without a ``.module``
            attribute is used directly.

    Yields:
        torch parameters, yaw head first, then pitch head.
    """
    # DataParallel-style wrappers expose the real network as ``.module``.
    net = getattr(model, 'module', model)
    for root in (net.fc_yaw_gaze, net.fc_pitch_gaze):
        for _, module in root.named_modules():
            # Direct parameters only: named_modules() already recurses, so the
            # recursive default would duplicate parameters of any nested head.
            yield from module.parameters(recurse=False)
def load_filtered_state_dict(model, snapshot):
    """Load into ``model`` only the ``snapshot`` entries whose keys it has.

    Keys present in ``snapshot`` but absent from ``model.state_dict()`` are
    ignored; the model's remaining entries keep their current values. Values
    under matching keys are passed to ``load_state_dict`` unchanged, so a
    shape mismatch is still reported by ``load_state_dict``.

    Adapted from a post by user apaszke on discuss.pytorch.org.

    Args:
        model: module to update in place.
        snapshot: mapping of state-dict key -> tensor, e.g. pretrained weights.
    """
    merged = model.state_dict()
    merged.update({key: value for key, value in snapshot.items() if key in merged})
    model.load_state_dict(merged)
def getArch_weights(arch, bins):
    """Build an L2CS model for a ResNet variant and return its weights URL.

    Args:
        arch: one of 'ResNet18', 'ResNet34', 'ResNet50', 'ResNet101',
            'ResNet152'. Any other value prints a warning and falls back to
            ResNet50.
        bins: forwarded unchanged as the third argument of ``L2CS``
            (presumably the number of angle bins — confirm in model.py).

    Returns:
        ``(model, pre_url)``: the constructed L2CS instance and the URL of the
        matching torchvision ImageNet checkpoint.
    """
    resnet = torchvision.models.resnet
    # name -> (residual block class, blocks per stage, pretrained URL)
    specs = {
        'ResNet18': (resnet.BasicBlock, [2, 2, 2, 2],
                     'https://download.pytorch.org/models/resnet18-5c106cde.pth'),
        'ResNet34': (resnet.BasicBlock, [3, 4, 6, 3],
                     'https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
        'ResNet50': (resnet.Bottleneck, [3, 4, 6, 3],
                     'https://download.pytorch.org/models/resnet50-19c8e357.pth'),
        'ResNet101': (resnet.Bottleneck, [3, 4, 23, 3],
                      'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
        'ResNet152': (resnet.Bottleneck, [3, 8, 36, 3],
                      'https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
    }
    if arch not in specs:
        print('Invalid value for architecture is passed! '
              'The default value of ResNet50 will be used instead!')
        arch = 'ResNet50'
    block, layers, pre_url = specs[arch]
    return L2CS(block, layers, bins), pre_url