import torch
from torch.utils.data import DataLoader, DistributedSampler, ConcatDataset
from torch.utils.data.sampler import BatchSampler, RandomSampler
from datasets.coco_style_dataset import CocoStyleDataset, CocoStyleDatasetTeaching
from models.backbones import ResNet50MultiScale, ResNet18MultiScale, ResNet101MultiScale
from models.positional_encoding import PositionEncodingSine
from models.deformable_detr import DeformableDETR
from models.deformable_transformer import DeformableTransformer
from models.criterion import SetCriterion
from datasets.augmentations import weak_aug, strong_aug, base_trans


def build_sampler(args, dataset, split):
    # Training uses a shuffled sampler and drops the last incomplete batch;
    # evaluation keeps the original sample order and all samples.
    if split == 'train':
        if args.distributed:
            sampler = DistributedSampler(dataset, shuffle=True, drop_last=True)
        else:
            sampler = RandomSampler(dataset)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    else:
        if args.distributed:
            sampler = DistributedSampler(dataset, shuffle=False)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)
        batch_sampler = BatchSampler(sampler, args.eval_batch_size, drop_last=False)
    return batch_sampler
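
# When args.distributed is set, the DistributedSampler built above must have its
# epoch advanced once per epoch so that shuffling differs between epochs. A
# minimal sketch of the caller's responsibility (`train_loader` and `epoch` are
# illustrative names; the actual training loop lives elsewhere):
#
#     for epoch in range(num_epochs):
#         sampler = train_loader.batch_sampler.sampler
#         if isinstance(sampler, DistributedSampler):
#             sampler.set_epoch(epoch)
#         ...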


def build_dataloader(args, dataset_name, domain, split, trans):
    dataset = CocoStyleDataset(root_dir=args.data_root,
                               dataset_name=dataset_name,
                               domain=domain,
                               split=split,
                               transforms=trans)
    batch_sampler = build_sampler(args, dataset, split)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=CocoStyleDataset.collate_fn,
                             num_workers=args.num_workers)
    return data_loader
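
# Example usage (a sketch: the dataset name, domain, and split are assumed
# values, and `args` must provide data_root, batch_size, eval_batch_size,
# distributed, and num_workers):
#
#     train_loader = build_dataloader(args, 'cityscapes', 'source', 'train', base_trans)
#     for batch in train_loader:
#         ...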


def build_dataloader_mae(args, source_dataset_name, target_dataset_name, split, trans):
    # Create the source-domain dataset
    source_dataset = CocoStyleDataset(root_dir=args.data_root,
                                      dataset_name=source_dataset_name,
                                      domain='source',
                                      split=split,
                                      transforms=trans)
    # Create the target-domain dataset
    target_dataset = CocoStyleDataset(root_dir=args.data_root,
                                      dataset_name=target_dataset_name,
                                      domain='target',
                                      split=split,
                                      transforms=trans)
    # Combine both domains into a single dataset
    combined_dataset = ConcatDataset([source_dataset, target_dataset])
    # Build a sampler and loader over the combined dataset
    batch_sampler = build_sampler(args, combined_dataset, split)
    data_loader = DataLoader(dataset=combined_dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=CocoStyleDataset.collate_fn,
                             num_workers=args.num_workers)
    return data_loader


def build_dataloader_teaching(args, dataset_name, domain, split):
    # The teaching dataset yields paired weakly and strongly augmented views of
    # each image, as required by teacher-student training.
    dataset = CocoStyleDatasetTeaching(root_dir=args.data_root,
                                       dataset_name=dataset_name,
                                       domain=domain,
                                       split=split,
                                       weak_aug=weak_aug,
                                       strong_aug=strong_aug,
                                       final_trans=base_trans)
    batch_sampler = build_sampler(args, dataset, split)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             collate_fn=CocoStyleDatasetTeaching.collate_fn_teaching,
                             num_workers=args.num_workers)
    return data_loader


def build_model(args, device):
    if args.backbone == 'resnet50':
        backbone = ResNet50MultiScale()
    elif args.backbone == 'resnet18':
        backbone = ResNet18MultiScale()
    elif args.backbone == 'resnet101':
        backbone = ResNet101MultiScale()
    else:
        raise ValueError('Invalid args.backbone name: ' + args.backbone)
    position_encoding = PositionEncodingSine()
    transformer = DeformableTransformer(
        hidden_dim=args.hidden_dim,
        num_heads=args.num_heads,
        num_encoder_layers=args.num_encoder_layers,
        num_decoder_layers=args.num_decoder_layers,
        feedforward_dim=args.feedforward_dim,
        dropout=args.dropout
    )
    model = DeformableDETR(
        backbone=backbone,
        position_encoding=position_encoding,
        transformer=transformer,
        num_classes=args.num_classes,
        num_queries=args.num_queries,
        num_feature_levels=args.num_feature_levels
    )
    model.to(device)
    return model
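
# The builders in this module only read attributes off `args`. A minimal sketch
# of the fields build_model consumes (the values below are illustrative
# assumptions, not the project's defaults):
#
#     from argparse import Namespace
#     args = Namespace(backbone='resnet50', hidden_dim=256, num_heads=8,
#                      num_encoder_layers=6, num_decoder_layers=6,
#                      feedforward_dim=1024, dropout=0.1, num_classes=9,
#                      num_queries=300, num_feature_levels=4)
#     model = build_model(args, torch.device('cuda'))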


def build_criterion(args, device, box_loss=True):
    criterion = SetCriterion(
        num_classes=args.num_classes,
        coef_class=args.coef_class,
        # Zero out the box regression and GIoU terms when box_loss is disabled
        coef_boxes=args.coef_boxes if box_loss else 0.0,
        coef_giou=args.coef_giou if box_loss else 0.0,
        coef_domain=args.coef_domain,
        coef_domain_bac=args.coef_domain_bac,
        coef_mae=args.coef_mae,
        alpha_focal=args.alpha_focal,
        alpha_dt=args.alpha_dt,
        gamma_dt=args.gamma_dt,
        max_dt=args.max_dt,
        device=device
    )
    criterion.to(device)
    return criterion


def build_optimizer(args, model, enable_mae=False):
    # Split parameters into three groups: backbone, the deformable-attention
    # linear projections (reference points and sampling offsets), and the rest.
    params_backbone = [param for name, param in model.named_parameters()
                       if 'backbone' in name]
    params_linear_proj = [param for name, param in model.named_parameters()
                          if 'reference_points' in name or 'sampling_offsets' in name]
    params = [param for name, param in model.named_parameters()
              if 'backbone' not in name and 'reference_points' not in name and 'sampling_offsets' not in name]
    param_dicts = [
        {'params': params, 'lr': args.lr},
        # The backbone is frozen (lr = 0) while MAE pretraining is enabled
        {'params': params_backbone, 'lr': 0.0 if enable_mae else args.lr_backbone},
        {'params': params_linear_proj, 'lr': args.lr_linear_proj},
    ]
    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    return optimizer
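
# The returned optimizer is typically paired with a learning-rate schedule; a
# sketch using a plain step decay (`args.epoch_lr_drop` is an assumed name, not
# necessarily what the training script uses):
#
#     optimizer = build_optimizer(args, model)
#     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epoch_lr_drop)
#     ...
#     lr_scheduler.step()  # call once per epoch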


def build_optimizer_mae(args, model, enable_mae=False):
    # Same parameter grouping as build_optimizer; both optimizers below share
    # these groups.
    params_backbone = [param for name, param in model.named_parameters()
                       if 'backbone' in name]
    params_linear_proj = [param for name, param in model.named_parameters()
                          if 'reference_points' in name or 'sampling_offsets' in name]
    params = [param for name, param in model.named_parameters()
              if 'backbone' not in name and 'reference_points' not in name and 'sampling_offsets' not in name]
    param_dicts = [
        {'params': params, 'lr': args.lr},
        {'params': params_backbone, 'lr': 0.0 if enable_mae else args.lr_backbone},
        {'params': params_linear_proj, 'lr': args.lr_linear_proj},
    ]
    param_dicts_mr = [
        {'params': params, 'lr': args.lr},
        {'params': params_backbone, 'lr': 0.0 if enable_mae else args.lr_backbone},
        {'params': params_linear_proj, 'lr': args.lr_linear_proj},
    ]
    if args.sgd:
        optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
    # The masked-retraining optimizer uses plain SGD with lr=args.mr_step in
    # either branch.
    optimizer_mr = torch.optim.SGD(param_dicts_mr, lr=args.mr_step, weight_decay=args.weight_decay)
    return optimizer, optimizer_mr


def build_teacher(args, student_model, device):
    # Initialize the teacher as a detached copy of the student's current weights.
    teacher_model = build_model(args, device)
    state_dict, student_state_dict = teacher_model.state_dict(), student_model.state_dict()
    for key in state_dict.keys():
        state_dict[key] = student_state_dict[key].clone().detach()
    teacher_model.load_state_dict(state_dict)
    return teacher_model
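
# In teacher-student self-training, the teacher built above is usually updated
# as an exponential moving average (EMA) of the student rather than by
# gradients. A minimal sketch of such an update step (the rule and the
# `alpha_ema` name are assumptions; this module only handles initialization):
#
#     @torch.no_grad()
#     def ema_update(teacher_model, student_model, alpha_ema=0.999):
#         for t_param, s_param in zip(teacher_model.parameters(),
#                                     student_model.parameters()):
#             t_param.mul_(alpha_ema).add_(s_param, alpha=1.0 - alpha_ema)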