diff --git a/DINO/datasets/__init__.py b/DINO/datasets/__init__.py
index 696cd9f..66a52e0 100644
--- a/DINO/datasets/__init__.py
+++ b/DINO/datasets/__init__.py
@@ -2,7 +2,8 @@
 import torch.utils.data
 import torchvision
 
-from .coco import build as build_coco
+# from .coco import build as build_coco
+from DINO.datasets.coco import build as build_coco
 
 
 def get_coco_api_from_dataset(dataset):
diff --git a/DINO/datasets/coco.py b/DINO/datasets/coco.py
index feca575..7aac303 100644
--- a/DINO/datasets/coco.py
+++ b/DINO/datasets/coco.py
@@ -19,9 +19,9 @@
 import torchvision
 from pycocotools import mask as coco_mask
 
-from datasets.data_util import preparing_dataset
-import datasets.transforms as T
-from util.box_ops import box_cxcywh_to_xyxy, box_iou
+from DINO.datasets.data_util import preparing_dataset
+import DINO.datasets.transforms as T
+from DINO.util.box_ops import box_cxcywh_to_xyxy, box_iou
 
 __all__ = ['build']
 
@@ -390,15 +390,16 @@ def __getitem__(self, idx):
         image_id = self.ids[idx]
         target = {'image_id': image_id, 'annotations': target}
         img, target = self.prepare(img, target)
         if self._transforms is not None:
             img, target = self._transforms(img, target)
 
+        # convert to needed format
         if self.aux_target_hacks is not None:
             for hack_runner in self.aux_target_hacks:
                 target, img = hack_runner(target, img=img)
 
         return img, target
@@ -434,6 +453,7 @@ def __call__(self, image, target):
         anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
 
         boxes = [obj["bbox"] for obj in anno]
+        # guard against no boxes via resizing
         boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
         boxes[:, 2:] += boxes[:, :2]
diff --git a/DINO/datasets/coco_eval.py b/DINO/datasets/coco_eval.py
index 3afa5e0..ee2df6e 100644
--- a/DINO/datasets/coco_eval.py
+++ b/DINO/datasets/coco_eval.py
@@ -16,7 +16,7 @@
 from pycocotools.coco import COCO
 import pycocotools.mask as mask_util
 
-from util.misc import all_gather
+from DINO.util.misc import all_gather
 
 
 class CocoEvaluator(object):
diff --git a/DINO/datasets/data_util.py b/DINO/datasets/data_util.py
index 8019ccd..f57f324 100644
--- a/DINO/datasets/data_util.py
+++ b/DINO/datasets/data_util.py
@@ -6,7 +6,7 @@
 
 import torch
 
-from util.slconfig import SLConfig
+from DINO.util.slconfig import SLConfig
 
 class Error(OSError):
     pass
diff --git a/DINO/datasets/panoptic_eval.py b/DINO/datasets/panoptic_eval.py
index 9cb4f83..99e35f8 100644
--- a/DINO/datasets/panoptic_eval.py
+++ b/DINO/datasets/panoptic_eval.py
@@ -2,7 +2,7 @@
 import json
 import os
 
-import util.misc as utils
+import DINO.util.misc as utils
 
 try:
     from panopticapi.evaluation import pq_compute
diff --git a/DINO/datasets/transforms.py b/DINO/datasets/transforms.py
index dd7ca57..254e8ca 100644
--- a/DINO/datasets/transforms.py
+++ b/DINO/datasets/transforms.py
@@ -9,8 +9,8 @@
 import torchvision.transforms as T
 import torchvision.transforms.functional as F
 
-from util.box_ops import box_xyxy_to_cxcywh
-from util.misc import interpolate
+from DINO.util.box_ops import box_xyxy_to_cxcywh
+from DINO.util.misc import interpolate
 
 
 def crop(image, target, region):
diff --git a/DINO/engine.py b/DINO/engine.py
index e5287f5..832a755 100644
--- a/DINO/engine.py
+++ b/DINO/engine.py
@@ -8,13 +8,13 @@
 import sys
 from typing import Iterable
 
-from util.utils import slprint, to_device
+from DINO.util.utils import slprint, to_device
 
 import torch
 
-import util.misc as utils
-from datasets.coco_eval import CocoEvaluator
-from datasets.panoptic_eval import PanopticEvaluator
+import DINO.util.misc as utils
+from DINO.datasets.coco_eval import CocoEvaluator
+from DINO.datasets.panoptic_eval import PanopticEvaluator
 
 
 def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
diff --git a/DINO/main.py b/DINO/main.py
index 6f384eb..c472a3e 100644
--- a/DINO/main.py
+++ b/DINO/main.py
@@ -10,25 +10,25 @@
 from pathlib import Path
 import os, sys
 from typing import Optional
 
-from util.get_param_dicts import get_param_dict
+from DINO.util.get_param_dicts import get_param_dict
 
-from util.logger import setup_logger
+from DINO.util.logger import setup_logger
 
 import numpy as np
 import torch
 from torch.utils.data import DataLoader, DistributedSampler
 import torch.distributed as dist
 
-import datasets
-import util.misc as utils
-from datasets import build_dataset, get_coco_api_from_dataset
-from engine import evaluate, train_one_epoch, test
-import models
-from util.slconfig import DictAction, SLConfig
-from util.utils import ModelEma, BestMetricHolder
+import DINO.datasets
+import DINO.util.misc as utils
+from DINO.datasets import build_dataset, get_coco_api_from_dataset
+from DINO.engine import evaluate, train_one_epoch, test
+import DINO.models
+from DINO.util.slconfig import DictAction, SLConfig
+from DINO.util.utils import ModelEma, BestMetricHolder
 
 
 def get_args_parser():
@@ -86,7 +86,7 @@ def get_args_parser():
 def build_model_main(args):
     # we use register to maintain models from catdet6 on.
-    from models.registry import MODULE_BUILD_FUNCS
+    from DINO.models.registry import MODULE_BUILD_FUNCS
     assert args.modelname in MODULE_BUILD_FUNCS._module_dict
     build_func = MODULE_BUILD_FUNCS.get(args.modelname)
     model, criterion, postprocessors = build_func(args)
diff --git a/DINO/models/dino/backbone.py b/DINO/models/dino/backbone.py
index 01c04d0..ef7b6f3 100644
--- a/DINO/models/dino/backbone.py
+++ b/DINO/models/dino/backbone.py
@@ -25,7 +25,7 @@
 
 from typing import Dict, List
 
-from util.misc import NestedTensor, clean_state_dict, is_main_process
+from DINO.util.misc import NestedTensor, clean_state_dict, is_main_process
 
 from .position_encoding import build_position_encoding
 from .convnext import build_convnext
diff --git a/DINO/models/dino/convnext.py b/DINO/models/dino/convnext.py
index 93cb006..e3d5db5 100644
--- a/DINO/models/dino/convnext.py
+++ b/DINO/models/dino/convnext.py
@@ -12,7 +12,7 @@
 import torch.nn.functional as F
 from timm.models.layers import trunc_normal_, DropPath
 
-from util.misc import NestedTensor
+from DINO.util.misc import NestedTensor
 
 # from timm.models.registry import register_model
 
 class Block(nn.Module):
diff --git a/DINO/models/dino/deformable_transformer.py b/DINO/models/dino/deformable_transformer.py
index e9b699e..46ac5f2 100644
--- a/DINO/models/dino/deformable_transformer.py
+++ b/DINO/models/dino/deformable_transformer.py
@@ -15,7 +15,7 @@
 import copy
 from typing import Optional
 
-from util.misc import inverse_sigmoid
+from DINO.util.misc import inverse_sigmoid
 
 import torch
 from torch import nn, Tensor
diff --git a/DINO/models/dino/dino.py b/DINO/models/dino/dino.py
index d8d09f8..9e895c3 100644
--- a/DINO/models/dino/dino.py
+++ b/DINO/models/dino/dino.py
@@ -21,8 +21,8 @@
 from torch import nn
 from torchvision.ops.boxes import nms
 
-from util import box_ops
-from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+from DINO.util import box_ops
+from DINO.util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                        accuracy, get_world_size, interpolate,
                        is_dist_avail_and_initialized, inverse_sigmoid)
diff --git a/DINO/models/dino/dn_components.py b/DINO/models/dino/dn_components.py
index e9f995b..bfca7fd 100644
--- a/DINO/models/dino/dn_components.py
+++ b/DINO/models/dino/dn_components.py
@@ -9,11 +9,11 @@
 
 import torch
 
-from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
+from DINO.util.misc import (NestedTensor, nested_tensor_from_tensor_list,
                        accuracy, get_world_size, interpolate,
                        is_dist_avail_and_initialized, inverse_sigmoid)
 # from .DABDETR import sigmoid_focal_loss
-from util import box_ops
+from DINO.util import box_ops
 import torch.nn.functional as F
diff --git a/DINO/models/dino/matcher.py b/DINO/models/dino/matcher.py
index 3b54730..2a9ad82 100644
--- a/DINO/models/dino/matcher.py
+++ b/DINO/models/dino/matcher.py
@@ -19,7 +19,7 @@
 from scipy.optimize import linear_sum_assignment
 from torch import nn
 
-from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+from DINO.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
 
 
 class HungarianMatcher(nn.Module):
diff --git a/DINO/models/dino/position_encoding.py b/DINO/models/dino/position_encoding.py
index 255a4d1..083408f 100644
--- a/DINO/models/dino/position_encoding.py
+++ b/DINO/models/dino/position_encoding.py
@@ -18,7 +18,7 @@
 import torch
 from torch import nn
 
-from util.misc import NestedTensor
+from DINO.util.misc import NestedTensor
 
 
 class PositionEmbeddingSine(nn.Module):
diff --git a/DINO/models/dino/segmentation.py b/DINO/models/dino/segmentation.py
index 1e09314..3af8096 100644
--- a/DINO/models/dino/segmentation.py
+++ b/DINO/models/dino/segmentation.py
@@ -24,8 +24,8 @@
 from torch import Tensor
 from PIL import Image
 
-import util.box_ops as box_ops
-from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
+import DINO.util.box_ops as box_ops
+from DINO.util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
 
 try:
     from panopticapi.utils import id2rgb, rgb2id
diff --git a/DINO/models/dino/swin_transformer.py b/DINO/models/dino/swin_transformer.py
index 63b6038..fc0299c 100644
--- a/DINO/models/dino/swin_transformer.py
+++ b/DINO/models/dino/swin_transformer.py
@@ -12,7 +12,7 @@
 import torch.utils.checkpoint as checkpoint
 import numpy as np
 from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from util.misc import NestedTensor
+from DINO.util.misc import NestedTensor
 
 
 class Mlp(nn.Module):
diff --git a/DINO/util/utils.py b/DINO/util/utils.py
index d747bef..4b9159a 100644
--- a/DINO/util/utils.py
+++ b/DINO/util/utils.py
@@ -202,7 +202,7 @@ def inverse_sigmoid(x, eps=1e-5):
     return torch.log(x1/x2)
 
 import argparse
-from util.slconfig import SLConfig
+from DINO.util.slconfig import SLConfig
 
 def get_raw_dict(args):
     """ return the dict contained in args.
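Note: the changes above rewrite DINO's top-level imports (`util.*`, `datasets.*`) as package-qualified `DINO.*` imports so the vendored `DINO/` tree can be imported from the parent repo. A minimal sanity check (a sketch; it assumes `DINO/` and its subpackages contain `__init__.py` files and that the parent repo root is on `sys.path`):

```python
# Sketch: verify the vendored DINO package resolves from the repo root.
import sys

sys.path.insert(0, ".")  # parent repo root that contains DINO/

from DINO.util.misc import NestedTensor              # was: from util.misc import ...
from DINO.datasets.coco import build as build_coco   # was: from .coco import ...
print(NestedTensor, build_coco)
```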
diff --git a/main.py b/main.py
index 3e9689a..0e51cc3 100644
--- a/main.py
+++ b/main.py
@@ -254,6 +254,60 @@ def _get_args_parser():
         help="Softmax temperature for part-seg model",
     )
     parser.add_argument("--save-all-epochs", action="store_true")
+
+    ### DINO args
+    # TODO: clean
+    from DINO.util.slconfig import DictAction
+    # parser.add_argument('--config_file', '-c', type=str, required=True)
+    parser.add_argument('--options',
+                        nargs='+',
+                        action=DictAction,
+                        help='override some settings in the used config; key-value pairs '
+                             'in xxx=yyy format will be merged into the config file.')
+
+    # dataset parameters
+    parser.add_argument('--dataset_file', default='coco')
+    parser.add_argument('--coco_path', type=str, default='/comp_robot/cv_public_dataset/COCO2017/')
+    parser.add_argument('--coco_panoptic_path', type=str)
+    parser.add_argument('--remove_difficult', action='store_true')
+    parser.add_argument('--fix_size', action='store_true')
+
+    # training parameters
+    parser.add_argument('--output_dir', default='',
+                        help='path where to save, empty for no saving')
+    parser.add_argument('--note', default='',
+                        help='add some notes to the experiment')
+    parser.add_argument('--device', default='cuda',
+                        help='device to use for training / testing')
+    parser.add_argument('--pretrain_model_path', help='load from other checkpoint')
+    parser.add_argument('--finetune_ignore', type=str, nargs='+')
+    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+                        help='start epoch')
+    parser.add_argument('--eval', action='store_true')
+    parser.add_argument('--num_workers', default=10, type=int)
+    parser.add_argument('--test', action='store_true')
+    parser.add_argument('--find_unused_params', action='store_true')
+
+    parser.add_argument('--save_results', action='store_true')
+    parser.add_argument('--save_log', action='store_true')
+
+    # distributed training parameters
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--dist_url', default='env://',
+                        help='url used to set up distributed training')
+    parser.add_argument('--local_rank', type=int,
+                        help='local rank for DistributedDataParallel')
+    parser.add_argument('--amp', action='store_true',
+                        help="Train with mixed precision")
+
     return parser
@@ -277,8 +331,78 @@ def main(args):
     loaders = load_dataset(args)
     train_loader, train_sampler, val_loader, test_loader = loaders
 
+    debug = False
+    if debug:
+        # DEBUGGING DATALOADER
+        for i, samples in enumerate(train_loader):
+            import torchvision
+
+            images, target_bbox, targets = samples
+            # NestedTensor.decompose() returns (padded image batch, padding mask)
+            images, _ = images.decompose()
+
+            debug_index = 1
+            # shape = images[debug_index].shape
+            torchvision.utils.save_image(images[debug_index], 'debug.png')
+            img_uint8 = torchvision.io.read_image('debug.png')
+            shape = target_bbox[debug_index]['size']
+            print(target_bbox[debug_index])
+
+            # cx, cy, w, h -> xmin, ymin, xmax, ymax (boxes are normalized)
+            boxes = target_bbox[debug_index]['boxes']
+            boxes[:, ::2] = boxes[:, ::2] * shape[1]
+            boxes[:, 1::2] = boxes[:, 1::2] * shape[0]
+
+            # note: views of boxes; read before their columns are overwritten below
+            box_width = boxes[:, 2]
+            box_height = boxes[:, 3]
+
+            boxes[:, 0] = boxes[:, 0] - box_width / 2
+            boxes[:, 2] = boxes[:, 0] + box_width
+            boxes[:, 1] = boxes[:, 1] - box_height / 2
+            boxes[:, 3] = boxes[:, 1] + box_height
+
+            boxes = boxes.to(torch.int)
+            img_with_boxes = torchvision.utils.draw_bounding_boxes(img_uint8, boxes=boxes, colors='red')
+            torchvision.utils.save_image(img_with_boxes / 255, 'debug_mask.png')
+            import pdb
+            pdb.set_trace()
+
     # Create model
     print("=> creating model")
+
+    if 'dino' in args.experiment:
+        from DINO.util.slconfig import SLConfig
+        # TODO: do not hardcode
+        args.config_file = 'DINO/config/DINO/DINO_4scale_modified.py'
+        # duplicate keys that exist are [num_classes, lr]
+        args.options = {'dn_scalar': 100}
+
+        # TODO: add as args. this was in the original script by dino, DINO_eval.sh
+        # args.embed_init_tgt = True
+        args.dn_label_coef = 1.0
+        args.dn_bbox_coef = 1.0
+        # args.use_ema = False
+        # args.dn_box_noise_scale = 1.0
+
+        # load cfg file and update the args
+        print("Loading config file from {}".format(args.config_file))
+        cfg = SLConfig.fromfile(args.config_file)
+        if args.options is not None:
+            cfg.merge_from_dict(args.options)
+
+        cfg_dict = cfg._cfg_dict.to_dict()
+        args_vars = vars(args)
+        for k, v in cfg_dict.items():
+            if k not in args_vars:
+                setattr(args, k, v)
+            else:
+                raise ValueError("Key {} can only be set by args".format(k))
+
     model, optimizer, scaler = build_model(args)
     cudnn.benchmark = True
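Note: the debug block above denormalizes cx,cy,w,h boxes and converts them to corners by hand, relying on `box_width`/`box_height` being tensor views that are read before their columns are overwritten. An equivalent, less fragile sketch using the helper already imported elsewhere in this diff (`DINO.util.box_ops.box_cxcywh_to_xyxy`):

```python
import torch
from DINO.util.box_ops import box_cxcywh_to_xyxy

# boxes: Tensor[num_box, 4] in normalized cx,cy,w,h; 'size' holds (h, w)
boxes = target_bbox[debug_index]['boxes']
h, w = target_bbox[debug_index]['size'].tolist()
boxes = box_cxcywh_to_xyxy(boxes) * torch.tensor([w, h, w, h])
boxes = boxes.to(torch.int)  # xmin, ymin, xmax, ymax in pixels
```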
@@ -457,39 +581,86 @@ def _train(
             images, targets, _ = samples
             segs = None
         else:
-            images, segs, targets = samples
-            segs = segs.cuda(args.gpu, non_blocking=True)
-
-        images = images.cuda(args.gpu, non_blocking=True)
-        targets = targets.cuda(args.gpu, non_blocking=True)
-        batch_size = images.size(0)
+            if 'dino' in args.experiment:
+                need_tgt_for_training = getattr(args, 'use_dn', False)
+
+                images, target_bbox, targets = samples
+                # class labels must be integer-typed for cross-entropy
+                targets = torch.tensor(targets, dtype=torch.long)
+            else:
+                images, segs, targets = samples
+                segs = segs.cuda(args.gpu, non_blocking=True)
+
+        if 'dino' in args.experiment:
+            from DINO.util.utils import to_device
+            device = 'cuda'
+            images = images.to(device)
+            target_bbox = [{k: to_device(v, device) for k, v in t.items()} for t in target_bbox]
+            # images = images.cuda(args.gpu, non_blocking=True)
+            # target_bbox = [{k: v.cuda(args.gpu, non_blocking=True) for k, v in t.items()} for t in target_bbox]
+            targets = targets.cuda(args.gpu, non_blocking=True)
+            batch_size = targets.size(0)
+            # batch_size = images.size(0)
+        else:
+            images = images.cuda(args.gpu, non_blocking=True)
+            targets = targets.cuda(args.gpu, non_blocking=True)
+            batch_size = images.size(0)
 
         # Compute output
         with amp.autocast(enabled=not args.full_precision):
-            if attack.use_mask:
-                # Attack for part models where both class and segmentation
-                # labels are used
-                images = attack(images, targets, segs)
-                if attack.dual_losses:
-                    targets = torch.cat([targets, targets], axis=0)
-                    segs = torch.cat([segs, segs], axis=0)
-            else:
-                # Attack for either classifier or segmenter alone
-                images = attack(images, targets)
+            if 'dino' in args.experiment:
+                if need_tgt_for_training:
+                    outputs, dino_outputs = model(images, return_mask=need_tgt_for_training)
+                else:
+                    outputs = model(images)
+                    dino_outputs = None
+
+                # criterion is SemiBBOXLoss:
+                # forward(logits, dino_outputs, dino_targets, targets, return_indices=False)
+                loss_dict = criterion(outputs, dino_outputs, target_bbox, targets)
+
+                weight_dict = criterion.weight_dict
+                loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
-            if segs is None or seg_only:
-                outputs = model(images)
-                loss = criterion(outputs, targets)
-            elif "groundtruth" in args.experiment:
-                outputs = model(images, segs=segs)
-                loss = criterion(outputs, targets)
             else:
-                outputs = model(images, return_mask=True)
-                loss = criterion(outputs, targets, segs)
-                outputs = outputs[0]
-
-            if args.adv_train in ("trades", "mat"):
-                outputs = outputs[batch_size:]
+                pass
+                # if attack.use_mask:
+                #     # Attack for part models where both class and segmentation
+                #     # labels are used
+                #     images = attack(images, targets, segs)
+                #     if attack.dual_losses:
+                #         targets = torch.cat([targets, targets], axis=0)
+                #         segs = torch.cat([segs, segs], axis=0)
+                # else:
+                #     # Attack for either classifier or segmenter alone
+                #     images = attack(images, targets)
+
+                # if segs is None or seg_only:
+                #     outputs = model(images)
+                #     loss = criterion(outputs, targets)
+                # elif "groundtruth" in args.experiment:
+                #     outputs = model(images, segs=segs)
+                #     loss = criterion(outputs, targets)
+                # else:
+                #     outputs = model(images, return_mask=True)
+                #     loss = criterion(outputs, targets, segs)
+                #     outputs = outputs[0]
+
+                # if args.adv_train in ("trades", "mat"):
+                #     outputs = outputs[batch_size:]
 
         if not math.isfinite(loss.item()):
             print("Loss is {}, stopping training".format(loss.item()))
@@ -562,17 +733,40 @@ def _validate(val_loader, model, criterion, attack, args):
             images, targets, _ = samples
             segs = None
         else:
-            images, segs, targets = samples
-            segs = segs.cuda(args.gpu, non_blocking=True)
+            if 'dino' in args.experiment:
+                need_tgt_for_training = getattr(args, 'use_dn', False)
+
+                images, target_bbox, targets = samples
+                # class labels must be integer-typed for cross-entropy
+                targets = torch.tensor(targets, dtype=torch.long)
+            else:
+                images, segs, targets = samples
+                segs = segs.cuda(args.gpu, non_blocking=True)
 
         # DEBUG
         if args.debug:
             save_image(COLORMAP[segs].permute(0, 3, 1, 2), "gt.png")
             save_image(images, "test.png")
 
-        images = images.cuda(args.gpu, non_blocking=True)
-        targets = targets.cuda(args.gpu, non_blocking=True)
-        batch_size = images.size(0)
+        if 'dino' in args.experiment:
+            from DINO.util.utils import to_device
+            device = 'cuda'
+            images = images.to(device)
+            target_bbox = [{k: to_device(v, device) for k, v in t.items()} for t in target_bbox]
+            # images = images.cuda(args.gpu, non_blocking=True)
+            # target_bbox = [{k: v.cuda(args.gpu, non_blocking=True) for k, v in t.items()} for t in target_bbox]
+            targets = targets.cuda(args.gpu, non_blocking=True)
+            batch_size = targets.size(0)
+            # batch_size = images.size(0)
+        else:
+            images = images.cuda(args.gpu, non_blocking=True)
+            targets = targets.cuda(args.gpu, non_blocking=True)
+            batch_size = images.size(0)
 
         # DEBUG: fixed clean segmentation masks
         if "clean" in args.experiment:
@@ -580,33 +774,47 @@
         # compute output
         with torch.no_grad():
-            if attack.use_mask:
-                images = attack(images, targets, segs)
+            if 'dino' in args.experiment:
+                # TODO: need to attack images
+                if need_tgt_for_training:
+                    outputs, dino_outputs = model(images, return_mask=need_tgt_for_training)
+                else:
+                    outputs = model(images)
+                    dino_outputs = None
+
+                loss_dict = criterion(outputs, dino_outputs, target_bbox, targets)
+
+                weight_dict = criterion.weight_dict
+                loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
             else:
-                images = attack(images, targets)
-
-            # Need to duplicate segs and targets to match images expanded by
-            # image corruption attack
-            if images.shape[0] != targets.shape[0]:
-                ratio = images.shape[0] // targets.shape[0]
-                targets = targets.repeat(
-                    (ratio,) + (1,) * (len(targets.shape) - 1)
-                )
-                if segs:
-                    segs = segs.repeat((ratio,) + (1,) * (len(segs.shape) - 1))
-
-            if segs is None or "normal" in args.experiment or seg_only:
-                outputs = model(images)
-            elif "groundtruth" in args.experiment:
-                outputs = model(images, segs=segs)
+                if attack.use_mask:
+                    images = attack(images, targets, segs)
+                else:
+                    images = attack(images, targets)
+
+                # Need to duplicate segs and targets to match images expanded by
+                # image corruption attack
+                if images.shape[0] != targets.shape[0]:
+                    ratio = images.shape[0] // targets.shape[0]
+                    targets = targets.repeat(
+                        (ratio,) + (1,) * (len(targets.shape) - 1)
+                    )
+                    if segs:
+                        segs = segs.repeat((ratio,) + (1,) * (len(segs.shape) - 1))
+
+                if segs is None or "normal" in args.experiment or seg_only:
+                    outputs = model(images)
+                elif "groundtruth" in args.experiment:
+                    outputs = model(images, segs=segs)
+                    loss = criterion(outputs, targets)
+                else:
+                    outputs, masks = model(images, return_mask=True)
+                    if "centroid" in args.experiment:
+                        masks, _, _, _ = masks
+                    pixel_acc = pixel_accuracy(masks, segs)
+                    pacc.update(pixel_acc.item(), batch_size)
                 loss = criterion(outputs, targets)
-            else:
-                outputs, masks = model(images, return_mask=True)
-                if "centroid" in args.experiment:
-                    masks, _, _, _ = masks
-                pixel_acc = pixel_accuracy(masks, segs)
-                pacc.update(pixel_acc.item(), batch_size)
-                loss = criterion(outputs, targets)
 
         # DEBUG
         # if args.debug and isinstance(attack, PGDAttackModule):
diff --git a/panoptic_parts/__init__.py b/panoptic_parts/__init__.py
index 44f5042..02c08a6 100644
--- a/panoptic_parts/__init__.py
+++ b/panoptic_parts/__init__.py
@@ -1,6 +1,8 @@
 from panoptic_parts.utils.format import decode_uids, encode_ids
 from panoptic_parts.utils.visualization import uid2color, random_colors
 from panoptic_parts.utils.utils import safe_write
-
+# from .panoptic_parts.utils.format import decode_uids, encode_ids
+# from .panoptic_parts.utils.visualization import uid2color, random_colors
+# from .panoptic_parts.utils.utils import safe_write
 
 __version__ = '2.0'
diff --git a/panoptic_parts/utils/format.py b/panoptic_parts/utils/format.py
index 3fa90fb..c5c10e3 100644
--- a/panoptic_parts/utils/format.py
+++ b/panoptic_parts/utils/format.py
@@ -13,6 +13,11 @@
     _sparse_ids_mapping_to_dense_ids_mapping as ndarray_from_dict
 from panoptic_parts.utils.utils import compare_pixelwise
 
+# from panoptic_parts.specs.dataset_spec import DatasetSpec
+# from panoptic_parts.utils.utils import \
+#     _sparse_ids_mapping_to_dense_ids_mapping as ndarray_from_dict
+# from panoptic_parts.utils.utils import compare_pixelwise
+
 TENSORFLOW_IMPORTED = False
 try:
     import tensorflow as tf  # pylint: disable=import-error
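Note: in the `_train`/`_validate` branches above, the DINO criterion returns an unweighted dict of loss components; the scalar objective is the `weight_dict`-masked sum. A sketch of the aggregation (the exact coefficients come from the config):

```python
# loss_dict holds components such as loss_ce, loss_bbox, loss_giou, plus
# their _dn, _interm, and per-decoder-layer _{i} variants; only keys that
# also appear in weight_dict contribute to the total.
weight_dict = criterion.weight_dict
loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
```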
diff --git a/part_model/dataloader/__init__.py b/part_model/dataloader/__init__.py
index 4b345ea..5cc77ea 100644
--- a/part_model/dataloader/__init__.py
+++ b/part_model/dataloader/__init__.py
@@ -1,7 +1,14 @@
 from .part_imagenet_corrupt import PART_IMAGENET_CORRUPT
 from .part_imagenet_mixed_next import PART_IMAGENET_MIXED
 from .cityscapes import CITYSCAPES
-from .part_imagenet import PART_IMAGENET
+from .part_imagenet import PART_IMAGENET, PART_IMAGENET_BBOX
 from .part_imagenet_geirhos import PART_IMAGENET_GEIRHOS
 from .pascal_part import PASCAL_PART
 from .pascal_voc import PASCAL_VOC
@@ -14,7 +21,8 @@
     'part-imagenet': PART_IMAGENET,
     'part-imagenet-geirhos': PART_IMAGENET_GEIRHOS,
     'part-imagenet-mixed': PART_IMAGENET_MIXED,
-    'part-imagenet-corrupt': PART_IMAGENET_CORRUPT
+    'part-imagenet-corrupt': PART_IMAGENET_CORRUPT,
+    'part-imagenet-bbox': PART_IMAGENET_BBOX
 }
diff --git a/part_model/dataloader/part_imagenet.py b/part_model/dataloader/part_imagenet.py
index e1a75c3..52ec3b1 100644
--- a/part_model/dataloader/part_imagenet.py
+++ b/part_model/dataloader/part_imagenet.py
@@ -1,6 +1,7 @@
 import os
 
 import numpy as np
+import json
 import torch
 import torch.utils.data as data
 from part_model.dataloader.util import COLORMAP
@@ -240,3 +241,323 @@
         "input_dim": (3, 224, 224),
         "colormap": COLORMAP,
     }
+
+
+import torchvision
+from DINO.datasets.coco import ConvertCocoPolysToMask, dataset_hook_register
+from pathlib import Path
+
+
+class PartImageNetBBOXDataset(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, class_label_file, ann_file, transforms, aux_target_hacks=None):
+        super(PartImageNetBBOXDataset, self).__init__(img_folder, ann_file)
+        self._transforms = transforms
+        self.prepare = ConvertCocoPolysToMask(False)
+        self.aux_target_hacks = aux_target_hacks
+
+        with open(class_label_file, 'r') as openfile:
+            self.image_to_label = json.load(openfile)
+
+        self.classes = list(sorted(set(self.image_to_label.values())))
+        self.num_classes = len(self.classes)
+
+    def change_hack_attr(self, hackclassname, attrkv_dict):
+        target_class = dataset_hook_register[hackclassname]
+        for item in self.aux_target_hacks:
+            if isinstance(item, target_class):
+                for k, v in attrkv_dict.items():
+                    setattr(item, k, v)
+
+    def get_hack(self, hackclassname):
+        target_class = dataset_hook_register[hackclassname]
+        for item in self.aux_target_hacks:
+            if isinstance(item, target_class):
+                return item
+
+    def __getitem__(self, idx):
+        """
+        Output:
+            - target: dict of multiple items
+                - boxes: Tensor[num_box, 4].
+                    Init type: x0,y0,x1,y1. unnormalized data.
+                    Final type: cx,cy,w,h. normalized data.
+        """
+        try:
+            img, target = super(PartImageNetBBOXDataset, self).__getitem__(idx)
+        except Exception:
+            # fall back to the next sample if this one fails to load
+            print("Error idx: {}".format(idx))
+            idx += 1
+            img, target = super(PartImageNetBBOXDataset, self).__getitem__(idx)
+        image_id = self.ids[idx]
+        target = {'image_id': image_id, 'annotations': target}
+        img, target = self.prepare(img, target)
+
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+
+        # convert to needed format
+        if self.aux_target_hacks is not None:
+            for hack_runner in self.aux_target_hacks:
+                target, img = hack_runner(target, img=img)
+
+        class_label = self.image_to_label[str(idx)]
+        return img, target, class_label
+def get_loader_sampler_bbox(args, transforms, split):
+    is_train = split == "train"
+
+    # TODO: add as arg
+    root = Path('/data/shared/PartImageNet/PartBoxSegmentations')
+
+    PATHS = {
+        "train": (root / "train", root / "image_labels" / 'train.json', root / "annotations" / 'train.json'),
+        "val": (root / "val", root / "image_labels" / 'val.json', root / "annotations" / 'val.json'),
+        "test": (root / "test", root / "image_labels" / 'test.json', root / "annotations" / 'test.json'),
+    }
+
+    img_folder, class_label_file, ann_file = PATHS[split]
+
+    part_imagenet_dataset = PartImageNetBBOXDataset(
+        img_folder,
+        class_label_file,
+        ann_file,
+        transforms,
+        aux_target_hacks=None,
+    )
+
+    sampler = None
+    if args.distributed:
+        shuffle = None
+        if is_train:
+            sampler = torch.utils.data.distributed.DistributedSampler(
+                part_imagenet_dataset,
+                shuffle=True,
+                seed=args.seed,
+                drop_last=False,
+            )
+        else:
+            # Use distributed sampler for validation but not testing
+            sampler = DistributedEvalSampler(
+                part_imagenet_dataset, shuffle=False, seed=args.seed
+            )
+    else:
+        # shuffle = is_train
+        shuffle = True
+
+    from DINO.util.misc import collate_fn
+    batch_size = args.batch_size
+    loader = torch.utils.data.DataLoader(
+        part_imagenet_dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        num_workers=args.workers,
+        pin_memory=True,
+        sampler=sampler,
+        drop_last=is_train,
+        collate_fn=collate_fn,
+    )
+
+    # TODO: can we make this cleaner?
+    # PART_IMAGENET["part_to_class"] = part_imagenet_dataset.part_to_class
+    PART_IMAGENET["num_classes"] = part_imagenet_dataset.num_classes
+    # PART_IMAGENET["num_seg_labels"] = part_imagenet_dataset.num_seg_labels
+
+    setattr(args, "num_classes", part_imagenet_dataset.num_classes)
+    # pto = part_imagenet_dataset.part_to_object
+    # if seg_type == "part":
+    #     seg_labels = len(pto)
+    # elif seg_type == "fg":
+    #     seg_labels = 2
+    # else:
+    #     seg_labels = pto.max().item() + 1
+    # setattr(args, "seg_labels", seg_labels)
+
+    return loader, sampler
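Note: `get_loader_sampler_bbox` uses DINO's `collate_fn`, which batches variable-size images into a `NestedTensor` (padded tensor plus padding mask); that is why the debug block in `main.py` calls `images.decompose()`. A sketch of unpacking one batch:

```python
images, target_bbox, class_labels = next(iter(loader))
tensors, mask = images.decompose()  # padded image batch and its padding mask
print(tensors.shape, mask.shape, len(target_bbox))
```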
+import DINO.datasets.transforms as T
+
+
+def make_coco_transforms(image_set, fix_size=False, strong_aug=False, args=None):
+    normalize = T.Compose([
+        T.ToTensor(),
+        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+    ])
+
+    # config the params for data aug
+    scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+    max_size = 1333
+    scales2_resize = [400, 500, 600]
+    scales2_crop = [384, 600]
+
+    # update args from config files
+    scales = getattr(args, 'data_aug_scales', scales)
+    max_size = getattr(args, 'data_aug_max_size', max_size)
+    scales2_resize = getattr(args, 'data_aug_scales2_resize', scales2_resize)
+    scales2_crop = getattr(args, 'data_aug_scales2_crop', scales2_crop)
+
+    # rescale the augmentation sizes if data_aug_scale_overlap is set
+    data_aug_scale_overlap = getattr(args, 'data_aug_scale_overlap', None)
+    if data_aug_scale_overlap is not None and data_aug_scale_overlap > 0:
+        data_aug_scale_overlap = float(data_aug_scale_overlap)
+        scales = [int(i * data_aug_scale_overlap) for i in scales]
+        max_size = int(max_size * data_aug_scale_overlap)
+        scales2_resize = [int(i * data_aug_scale_overlap) for i in scales2_resize]
+        scales2_crop = [int(i * data_aug_scale_overlap) for i in scales2_crop]
+
+    datadict_for_print = {
+        'scales': scales,
+        'max_size': max_size,
+        'scales2_resize': scales2_resize,
+        'scales2_crop': scales2_crop,
+    }
+    print("data_aug_params:", json.dumps(datadict_for_print, indent=2))
+
+    if image_set == 'train':
+        if fix_size:
+            return T.Compose([
+                T.RandomHorizontalFlip(),
+                T.RandomResize([(max_size, max(scales))]),
+                # T.RandomResize([(512, 512)]),
+                normalize,
+            ])
+
+        if strong_aug:
+            import DINO.datasets.sltransform as SLT
+
+            return T.Compose([
+                T.RandomHorizontalFlip(),
+                T.RandomSelect(
+                    T.RandomResize(scales, max_size=max_size),
+                    T.Compose([
+                        T.RandomResize(scales2_resize),
+                        T.RandomSizeCrop(*scales2_crop),
+                        T.RandomResize(scales, max_size=max_size),
+                    ])
+                ),
+                SLT.RandomSelectMulti([
+                    SLT.RandomCrop(),
+                    # SLT.Rotate(10),
+                    SLT.LightingNoise(),
+                    SLT.AdjustBrightness(2),
+                    SLT.AdjustContrast(2),
+                ]),
+                normalize,
+            ])
+
+        return T.Compose([
+            T.RandomHorizontalFlip(),
+            T.RandomSelect(
+                T.RandomResize(scales, max_size=max_size),
+                T.Compose([
+                    T.RandomResize(scales2_resize),
+                    T.RandomSizeCrop(*scales2_crop),
+                    T.RandomResize(scales, max_size=max_size),
+                ])
+            ),
+            normalize,
+        ])
+
+    if image_set in ['val', 'eval_debug', 'train_reg', 'test']:
+        if os.environ.get("GFLOPS_DEBUG_SHILONG", False) == 'INFO':
+            print("Under debug mode for flops calculation only!")
+            return T.Compose([
+                T.ResizeDebug((1280, 800)),
+                normalize,
+            ])
+
+        return T.Compose([
+            T.RandomResize([max(scales)], max_size=max_size),
+            normalize,
+        ])
+
+    raise ValueError(f'unknown {image_set}')
+
+
+def load_part_imagenet_bbox(args):
+    strong_aug = getattr(args, 'strong_aug', False)
+    train_transforms = make_coco_transforms("train", fix_size=args.fix_size, strong_aug=strong_aug, args=args)
+    train_loader, train_sampler = get_loader_sampler_bbox(
+        args, train_transforms, "train"
+    )
+
+    val_transforms = make_coco_transforms("val", fix_size=args.fix_size, strong_aug=False, args=args)
+    val_loader, _ = get_loader_sampler_bbox(args, val_transforms, "val")
+    test_loader, _ = get_loader_sampler_bbox(args, val_transforms, "test")
+
+    return train_loader, train_sampler, val_loader, test_loader
+
+
+PART_IMAGENET_BBOX = {
+    "normalize": {
+        "mean": [0.485, 0.456, 0.406],
+        "std": [0.229, 0.224, 0.225],
+    },
+    "loader": load_part_imagenet_bbox,
+    "input_dim": (3, 224, 224),
+    "colormap": COLORMAP,
+}
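Note: with the `part-imagenet-bbox` entry registered in `DATASET_DICT` (see the `part_model/dataloader/__init__.py` diff above), the new loaders are reachable through the existing dataset plumbing. A sketch of direct use:

```python
from part_model.dataloader import DATASET_DICT

entry = DATASET_DICT['part-imagenet-bbox']
train_loader, train_sampler, val_loader, test_loader = entry['loader'](args)
```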
diff --git a/part_model/models/__init__.py b/part_model/models/__init__.py
index 12aa38a..87d4614 100644
--- a/part_model/models/__init__.py
+++ b/part_model/models/__init__.py
@@ -9,6 +9,7 @@
 from ..dataloader import DATASET_DICT
 from ..utils.image import get_seg_type
 from .bbox_model import BoundingBoxModel
+from .dino_bbox_model import DinoBoundingBoxModel
 from .clean_mask_model import CleanMaskModel
 from .common import Normalize
 from .groundtruth_mask_model import GroundtruthMaskModel
@@ -22,7 +23,6 @@
 from .two_head_model import TwoHeadModel
 from .weighted_bbox_model import WeightedBBoxModel
 
-
 def wrap_distributed(args, model):
     if args.distributed:
         model.cuda(args.gpu)
@@ -61,6 +61,19 @@ def build_classifier(args):
     tokens = args.experiment.split("-")
     model_token = tokens[1]
     exp_tokens = tokens[2:]
+
+    if model_token == "dino":
+        # TODO: this is incomplete. DINO only predicts bounding boxes, so a
+        # wrapper is needed to predict classes from the boxes. DINO could
+        # also be used to predict segmentation; try that too.
+        # model, criterion, postprocessors = build_model_main(args)
+        # TODO: pass dino_args instead of args?
+        model = DinoBoundingBoxModel(args)
+
     print("=> building segmentation model...")
     segmenter = SEGM_BUILDER[args.seg_arch](args)
@@ -143,7 +156,7 @@
             sum(p.numel() for p in segmenter.parameters() if p.requires_grad)
             / 1e6
         )
-        print(f"=> segmenter params (train/total): {nt_seg:.2f}M/{n_seg:.2f}M")
+        print(f"=> segmenter params (train/total): {nt_seg:.2f}M/{n_seg:.2f}M")
     else:
         print("=> building a normal classifier (no segmentation)")
         model.fc = nn.Linear(rep_dim, args.num_classes)
diff --git a/part_model/utils/loss.py b/part_model/utils/loss.py
index 3b5fe80..b999518 100644
--- a/part_model/utils/loss.py
+++ b/part_model/utils/loss.py
@@ -315,7 +315,107 @@ def get_train_criterion(args):
     elif "semi" in args.experiment:
         if "centroid" in args.experiment:
             train_criterion = SemiKeypointLoss(seg_const=args.seg_const_trn)
+        elif "dino" in args.experiment:
+            losses = ['labels', 'boxes', 'cardinality']
+            from DINO.models.dino.matcher import build_matcher
+            import copy
+
+            matcher = build_matcher(args)
+
+            # prepare weight dict
+            weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef}
+            weight_dict['loss_giou'] = args.giou_loss_coef
+            clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)
+
+            # for DN training
+            if args.use_dn:
+                weight_dict['loss_ce_dn'] = args.cls_loss_coef
+                weight_dict['loss_bbox_dn'] = args.bbox_loss_coef
+                weight_dict['loss_giou_dn'] = args.giou_loss_coef
+
+            if args.masks:
+                weight_dict["loss_mask"] = args.mask_loss_coef
+                weight_dict["loss_dice"] = args.dice_loss_coef
+            clean_weight_dict = copy.deepcopy(weight_dict)
+
+            # TODO: this is a hack
+            if args.aux_loss:
+                aux_weight_dict = {}
+                for i in range(args.dec_layers - 1):
+                    aux_weight_dict.update({k + f'_{i}': v for k, v in clean_weight_dict.items()})
+                weight_dict.update(aux_weight_dict)
+
+            if args.two_stage_type != 'no':
+                interm_weight_dict = {}
+                no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
+                _coeff_weight_dict = {
+                    'loss_ce': 1.0,
+                    'loss_bbox': 1.0 if not no_interm_box_loss else 0.0,
+                    'loss_giou': 1.0 if not no_interm_box_loss else 0.0,
+                }
+                interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
+                interm_weight_dict.update({
+                    k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
+                    for k, v in clean_weight_dict_wo_dn.items()
+                })
+                weight_dict.update(interm_weight_dict)
+
+            # criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
+            #                          focal_alpha=args.focal_alpha, losses=losses)
+            train_criterion = SemiBBOXLoss(
+                args.num_classes, matcher=matcher, weight_dict=weight_dict,
+                focal_alpha=args.focal_alpha, losses=losses, seg_const=args.seg_const_trn)
         else:
             train_criterion = SemiSumLoss(seg_const=args.seg_const_trn)
     train_criterion = train_criterion.cuda(args.gpu)
 
     return criterion, train_criterion
+
+
+from DINO.models.dino.dino import SetCriterion
+
+
+class SemiBBOXLoss(SetCriterion):
+    def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses,
+                 seg_const: float = 0.5, reduction: str = "mean"):
+        super(SemiBBOXLoss, self).__init__(num_classes, matcher, weight_dict, focal_alpha, losses)
+        assert 0 <= seg_const <= 1
+        self.seg_const = seg_const
+        self.reduction = reduction
+        self.weight_dict = weight_dict
+
+    def forward(
+        self,
+        logits: Union[list, tuple],
+        dino_outputs: dict,
+        dino_targets: list,
+        targets: torch.Tensor,
+        return_indices=False,
+    ):
+        """Compute the combined classification + DINO detection loss.
+
+        Parameters:
+            logits: classifier outputs used for the cross-entropy term
+            dino_outputs: dict of tensors, see DINO's output specification
+            dino_targets: list of dicts, len(dino_targets) == batch_size; the
+                expected keys in each dict depend on the losses applied
+            targets: class labels for the cross-entropy term
+            return_indices: used for visualization; if True, the layer0-5
+                indices are returned as well
+        """
+        loss = 0
+        if self.seg_const < 1:
+            clf_loss = F.cross_entropy(logits, targets, reduction="none")
+            loss += (1 - self.seg_const) * clf_loss
+        if self.seg_const > 0:
+            loss_dict = super().forward(dino_outputs, dino_targets, return_indices=return_indices)
+            bbox_loss = sum(loss_dict[k] * self.weight_dict[k] for k in loss_dict.keys() if k in self.weight_dict)
+            loss += self.seg_const * bbox_loss
+        if self.reduction == "mean":
+            return loss.mean()
+
+        return loss
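Note: written out, `SemiBBOXLoss` implements a convex combination of the classification and detection objectives. A sketch of the same computation outside the class:

```python
import torch.nn.functional as F

# total = (1 - c) * CE(logits, targets) + c * sum_k w_k * L_k(dino)
clf_loss = F.cross_entropy(logits, targets, reduction="none").mean()
bbox_loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)
total = (1 - seg_const) * clf_loss + seg_const * bbox_loss
```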