Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding DINO's contrastive sampling #45

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
with torch.cuda.amp.autocast(enabled=args.amp):
if need_tgt_for_training:
outputs, mask_dict = model(samples, dn_args=(targets, args.scalar, args.label_noise_scale,
args.box_noise_scale, args.num_patterns))
args.box_noise_scale, args.num_patterns, args.contrastive))
loss_dict = criterion(outputs, targets, mask_dict)
else:
outputs = model(samples)
Expand Down Expand Up @@ -82,6 +82,7 @@ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
scaler.update()
else:
# original backward function
optimizer.zero_grad()
losses.backward()
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
Expand Down
20 changes: 15 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
import random
import time
from pathlib import Path
from os import path
import os, sys
from typing import Optional


from util.logger import setup_logger

import numpy as np
Expand Down Expand Up @@ -39,6 +39,8 @@ def get_args_parser():
help="label noise ratio to flip")
parser.add_argument('--box_noise_scale', default=0.4, type=float,
help="box noise scale to shift and scale")
parser.add_argument('--contrastive', action="store_true",
help="use contrastive training.")

# about lr
parser.add_argument('--lr', default=1e-4, type=float,
Expand All @@ -50,6 +52,7 @@ def get_args_parser():
parser.add_argument('--weight_decay', default=1e-4, type=float)
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--lr_drop', default=40, type=int)
parser.add_argument('--override_resumed_lr_drop', default=False, action='store_true')
parser.add_argument('--drop_lr_now', action="store_true", help="load checkpoint and drop for 12epoch setting")
parser.add_argument('--save_checkpoint_interval', default=10, type=int)
parser.add_argument('--clip_max_norm', default=0.1, type=float,
Expand Down Expand Up @@ -94,6 +97,8 @@ def get_args_parser():
help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=300, type=int,
help="Number of query slots")
parser.add_argument('--num_results', default=300, type=int,
help="Number of detection results")
parser.add_argument('--pre_norm', action='store_true',
help="Using pre-norm in the Transformer blocks.")
parser.add_argument('--num_select', default=300, type=int,
Expand Down Expand Up @@ -170,7 +175,7 @@ def get_args_parser():
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--debug', action='store_true',
help="For debug only. It will perform only a few steps during training and val.")
parser.add_argument('--find_unused_params', action='store_true')
parser.add_argument('--find_unused_params', default=False, action='store_true')

parser.add_argument('--save_results', action='store_true',
help="For eval only. Save the outputs for all images.")
Expand Down Expand Up @@ -222,8 +227,8 @@ def main(args):
logger.info('local_rank: {}'.format(args.local_rank))
logger.info("args: " + str(args) + '\n')

if args.frozen_weights is not None:
assert args.masks, "Frozen training is meant for segmentation only"
#if args.frozen_weights is not None:
# assert args.masks, "Frozen training is meant for segmentation only"
print(args)

device = torch.device(args.device)
Expand Down Expand Up @@ -293,7 +298,7 @@ def main(args):
model_without_ddp.detr.load_state_dict(checkpoint['model'])

output_dir = Path(args.output_dir)
if args.resume:
if args.resume and (args.resume.startswith('https') or path.exists(args.resume)):
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
Expand All @@ -303,6 +308,11 @@ def main(args):
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
if args.override_resumed_lr_drop:
print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
lr_scheduler.step_size = args.lr_drop
lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
lr_scheduler.step(lr_scheduler.last_epoch)
args.start_epoch = checkpoint['epoch'] + 1

if args.drop_lr_now:
Expand Down
61 changes: 39 additions & 22 deletions models/dn_dab_deformable_detr/dab_deformable_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,22 @@ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_
self.use_dab = use_dab
self.num_patterns = num_patterns
self.random_refpoints_xy = random_refpoints_xy
self.two_stage = two_stage
# dn label enc
self.label_enc = nn.Embedding(num_classes + 1, hidden_dim - 1) # # for indicator
if not two_stage:
if not use_dab:
self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
else:
if not use_dab:
self.query_embed = nn.Embedding(num_queries, hidden_dim*2)
else:
if not self.two_stage:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim-1) # for indicator
self.refpoint_embed = nn.Embedding(num_queries, 4)

if random_refpoints_xy:
# import ipdb; ipdb.set_trace()
self.refpoint_embed.weight.data[:, :2].uniform_(0,1)
self.refpoint_embed.weight.data[:, :2] = inverse_sigmoid(self.refpoint_embed.weight.data[:, :2])
self.refpoint_embed.weight.data[:, :2].requires_grad = False


if self.num_patterns > 0:
self.patterns_embed = nn.Embedding(self.num_patterns, hidden_dim)

Expand Down Expand Up @@ -116,7 +117,6 @@ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_
self.backbone = backbone
self.aux_loss = aux_loss
self.with_box_refine = with_box_refine
self.two_stage = two_stage

prior_prob = 0.01
bias_value = -math.log((1 - prior_prob) / prior_prob)
Expand All @@ -129,6 +129,7 @@ def __init__(self, backbone, transformer, num_classes, num_queries, num_feature_

# if two-stage, the last class_embed and bbox_embed is for region proposal generation
num_pred = (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers

if with_box_refine:
self.class_embed = _get_clones(self.class_embed, num_pred)
self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
Expand Down Expand Up @@ -188,33 +189,43 @@ def forward(self, samples: NestedTensor, dn_args=None):
masks.append(mask)
pos.append(pos_l)

if self.two_stage:
assert NotImplementedError
elif self.use_dab:
if self.num_patterns == 0:
tgt_all_embed = tgt_embed = self.tgt_embed.weight # nq, 256
refanchor = self.refpoint_embed.weight # nq, 4
# query_embeds = torch.cat((tgt_embed, refanchor), dim=1)
#if self.two_stage:
# assert NotImplementedError
#elif self.use_dab:
if self.use_dab:
if not self.two_stage:
if self.num_patterns == 0:
tgt_all_embed = tgt_embed = self.tgt_embed.weight # nq, 256
refanchor = self.refpoint_embed.weight # nq, 4
# query_embeds = torch.cat((tgt_embed, refanchor), dim=1)
else:
# multi patterns is not used in this version
assert NotImplementedError
else:
# multi patterns is not used in this version
assert NotImplementedError
tgt_all_embed = None
refanchor = None
else:
assert NotImplementedError

# prepare for dn
input_query_label, input_query_bbox, attn_mask, mask_dict = \
prepare_for_dn(dn_args, tgt_all_embed, refanchor, src.size(0), self.training, self.num_queries, self.num_classes,
self.hidden_dim, self.label_enc)
query_embeds = torch.cat((input_query_label, input_query_bbox), dim=2)
if input_query_label is not None and input_query_bbox is not None:
# sometimes the target is empty, add a zero part of label_enc to avoid unused parameters
input_query_label += self.label_enc.weight[0][0]*torch.tensor(0).cuda()
query_embeds = torch.cat((input_query_label, input_query_bbox), dim=2)
else:
query_embeds = None

hs, init_reference, inter_references, enc_outputs_class, enc_outputs_coord_unact = \
self.transformer(srcs, masks, pos, query_embeds, attn_mask)


levels = hs.shape[0]

outputs_classes = []
outputs_coords = []
for lvl in range(hs.shape[0]):
#for lvl in range(hs.shape[0]):
for lvl in range(levels):
if lvl == 0:
reference = init_reference
else:
Expand Down Expand Up @@ -462,6 +473,10 @@ def forward(self, outputs, targets, mask_dict=None):

class PostProcess(nn.Module):
""" This module converts the model's output into the format expected by the coco api"""
# Configure COCO-style post-processing: `forward` keeps the top
# `num_select` scoring detections per image (passed to torch.topk below
# in this diff). `nms_iou_threshold` defaults to -1;
# NOTE(review): it is stored but not read anywhere in the visible hunk —
# presumably a value > 0 enables NMS in code outside this view; confirm.
# NOTE(review): indentation below is flattened by the diff extraction;
# these statements are the method body in the original file.
def __init__(self, num_select=300, nms_iou_threshold=-1) -> None:
super().__init__()
self.num_select = num_select  # top-k count used by forward()
self.nms_iou_threshold = nms_iou_threshold  # -1 disables NMS (unused in visible code)

@torch.no_grad()
def forward(self, outputs, target_sizes):
Expand All @@ -472,13 +487,14 @@ def forward(self, outputs, target_sizes):
For evaluation, this must be the original image size (before any data augmentation)
For visualization, this should be the image size after data augment, but before padding
"""
num_select = self.num_select
out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

assert len(out_logits) == len(target_sizes)
assert target_sizes.shape[1] == 2

prob = out_logits.sigmoid()
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1)
topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), num_select, dim=1)
scores = topk_values
topk_boxes = topk_indexes // out_logits.shape[2]
labels = topk_indexes % out_logits.shape[2]
Expand Down Expand Up @@ -547,7 +563,8 @@ def build_dab_deformable_detr(args):
# TODO this is a hack
if args.aux_loss:
aux_weight_dict = {}
for i in range(args.dec_layers - 1):
num_layers = args.dec_layers
for i in range(num_layers - 1):
aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
aux_weight_dict.update({k + f'_enc': v for k, v in weight_dict.items()})
weight_dict.update(aux_weight_dict)
Expand All @@ -558,7 +575,7 @@ def build_dab_deformable_detr(args):
# num_classes, matcher, weight_dict, losses, focal_alpha=0.25
criterion = SetCriterion(num_classes, matcher, weight_dict, losses, focal_alpha=args.focal_alpha)
criterion.to(device)
postprocessors = {'bbox': PostProcess()}
postprocessors = {'bbox': PostProcess(num_select=args.num_results)}
if args.masks:
postprocessors['segm'] = PostProcessSegm()
if args.dataset_file == "coco_panoptic":
Expand Down
Loading