Skip to content

Commit

Permalink
Merge branch 'master' of github.com:chawins/adv-part-model
Browse files Browse the repository at this point in the history
  • Loading branch information
chawins committed Feb 14, 2023
2 parents a3c8f4b + 1b04806 commit 148019c
Show file tree
Hide file tree
Showing 22 changed files with 454 additions and 381 deletions.
5 changes: 3 additions & 2 deletions DINO/config/DINO/DINO_4scale_modified.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# lr = 0.0001
param_dict_type = 'default'
lr_backbone = 1e-05
lr_backbone = 1e-02
lr_backbone_names = ['backbone.0']
lr_linear_proj_names = ['reference_points', 'sampling_offsets']
lr_linear_proj_mult = 0.1
Expand Down Expand Up @@ -100,7 +100,8 @@
dn_box_noise_scale = 0.4
dn_label_noise_ratio = 0.5
embed_init_tgt = True
dn_labelbook_size = 91
dn_labelbook_size = 41
# dn_labelbook_size = 91

match_unstable_error = True

Expand Down
141 changes: 130 additions & 11 deletions DINO/datasets/partimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ def __getitem__(self, idx):
image_id = self.ids[idx]
target = {'image_id': image_id, 'annotations': target}
img, target = self.prepare(img, target)


if self._transforms is not None:
img, target = self._transforms(img, target)
Expand All @@ -409,7 +410,7 @@ def __getitem__(self, idx):
if self.aux_target_hacks is not None:
for hack_runner in self.aux_target_hacks:
target, img = hack_runner(target, img=img)

# class_label = self.image_to_label[str(idx)]
return img, target

Expand Down Expand Up @@ -711,17 +712,113 @@ def get_aux_target_hacks_list(image_set, args):
return aux_target_hacks_list


# TODO: clean imports
import PIL
from torchvision.transforms import RandomResizedCrop
from DINO.datasets.transforms import crop, resize

class RandomCrop(object):
def __init__(
self,
scale=(0.08, 1.0),
ratio=(0.75, 1.333333),
) -> None:
self.scale = scale
self.ratio = ratio

def __call__(self, img: PIL.Image.Image, target: dict):
region = RandomResizedCrop.get_params(img, self.scale, self.ratio)
img, target = crop(img, target, region)
return img, target


class Resize(object):
def __init__(self, size) -> None:
self.size: int = size

def __call__(self, img: PIL.Image.Image, target: dict):
img, target = resize(img, target, self.size)
return img, target



def build(image_set, args):
# TODO: add as arg
root = Path('/data/shared/PartImageNet/PartBoxSegmentations')
# root = Path('/global/scratch/users/nabeel126/PartImageNet/PartBoxSegmentations')

PATHS = {
"train": (root / "train", root / "image_labels" / 'train.json', root / "annotations" / 'train.json'),
"val": (root / "val", root / "image_labels" / 'val.json', root / "annotations" / 'val.json'),
"test": (root / "test", root / "image_labels" / 'test.json', root / "annotations" / 'test.json' ),
}
root = Path(args.bbox_label_dir)

if args.use_imagenet_classes:
if args.group_parts:
PATHS = {
"train": (
root / "train",
root / "image_labels" / "imagenet" / "grouped" / "train.json",
root / "annotations" / "imagenet" / "grouped" / "train.json",
),
"val": (
root / "val",
root / "image_labels" / "imagenet" / "grouped" / "val.json",
root / "annotations" / "imagenet" / "grouped" / "val.json",
),
"test": (
root / "test",
root / "image_labels" / "imagenet" / "grouped" / "test.json",
root / "annotations" / "imagenet" / "grouped" / "test.json",
),
}
else:
PATHS = {
"train": (
root / "train",
root / "image_labels" / "imagenet" / "all" / "train.json",
root / "annotations" / "imagenet" / "all" / "train.json",
),
"val": (
root / "val",
root / "image_labels" / "imagenet" / "all" / "val.json",
root / "annotations" / "imagenet" / "all" / "val.json",
),
"test": (
root / "test",
root / "image_labels" / "imagenet" / "all" / "test.json",
root / "annotations" / "imagenet" / "all" / "test.json",
),
}
else:
if args.group_parts:
PATHS = {
"train": (
root / "train",
root / "image_labels" / "partimagenet" / "grouped" / "train.json",
root / "annotations" / "partimagenet" / "grouped" / "train.json",
),
"val": (
root / "val",
root / "image_labels" / "partimagenet" / "grouped" / "val.json",
root / "annotations" / "partimagenet" / "grouped" / "val.json",
),
"test": (
root / "test",
root / "image_labels" / "partimagenet" / "grouped" / "test.json",
root / "annotations" / "partimagenet" / "grouped" / "test.json",
),
}
else:
PATHS = {
"train": (
root / "train",
root / "image_labels" / "partimagenet" / "all" / "train.json",
root / "annotations" / "partimagenet" / "all" / "train.json",
),
"val": (
root / "val",
root / "image_labels" / "partimagenet" / "all" / "val.json",
root / "annotations" / "partimagenet" / "all" / "val.json",
),
"test": (
root / "test",
root / "image_labels" / "partimagenet" / "all" / "test.json",
root / "annotations" / "partimagenet" / "all" / "test.json",
),
}

img_folder, class_label_file, ann_file = PATHS[image_set]

Expand All @@ -738,7 +835,29 @@ def build(image_set, args):
except:
strong_aug = False

transforms = make_coco_transforms(image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args)
# transforms = make_coco_transforms(image_set, fix_size=args.fix_size, strong_aug=strong_aug, args=args)

img_size = 224
if image_set == 'train':
transforms = T.Compose(
[
RandomCrop(),
Resize([img_size, img_size]),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize([0, 0, 0], [1, 1, 1]),
]
)
elif image_set in ['val', 'eval_debug', 'train_reg', 'test']:
transforms = T.Compose(
[
Resize([int(img_size * 256 / 224), int(img_size * 256 / 224)]),
T.CenterCrop([img_size, img_size]),
T.ToTensor(),
T.Normalize([0, 0, 0], [1, 1, 1]),
]
)


dataset = PartImageNetBBOXDataset(
img_folder,
Expand Down
5 changes: 3 additions & 2 deletions DINO/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Optional
from DINO.util.get_param_dicts import get_param_dict


# sys.path.insert(0, '..')


from DINO.util.logger import setup_logger
Expand Down Expand Up @@ -78,7 +78,8 @@ def get_args_parser():
parser.add_argument('--rank', default=0, type=int,
help='number of distributed processes')
parser.add_argument("--local_rank", type=int, help='local rank for DistributedDataParallel')

parser.add_argument('--amp', action='store_true',
help="Train with mixed precision")

return parser

Expand Down
5 changes: 4 additions & 1 deletion DINO/models/dino/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,12 @@ def __init__(self, name: str,
batch_norm=FrozenBatchNorm2d,
):
if name in ['resnet18', 'resnet34', 'resnet50', 'resnet101']:
# backbone = getattr(torchvision.models, name)(
# replace_stride_with_dilation=[False, False, dilation],
# pretrained=is_main_process(), norm_layer=batch_norm)
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
pretrained=is_main_process(), norm_layer=batch_norm)
pretrained=is_main_process())
else:
raise NotImplementedError("Why you can get here with name {}".format(name))
# num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
Expand Down
31 changes: 15 additions & 16 deletions autoattack_modified/autoattack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from .flags import SET_MASK
from .other_utils import Logger, get_pred
from .other_utils import Logger, get_pred, mask_kwargs


class AutoAttack:
Expand Down Expand Up @@ -164,16 +164,16 @@ def __init__(
# if version in ['standard', 'plus', 'rand']:
self.set_version(version)

def get_logits(self, x):
def get_logits(self, x, **kwargs):
if not self.is_tf_model:
return self.model(x)
return self.model(x, **kwargs)
else:
return self.model.predict(x)

def get_seed(self):
return time.time() if self.seed is None else self.seed

def run_standard_evaluation(self, x_orig, y_orig, bs=250):
def run_standard_evaluation(self, x_orig, y_orig, bs=250, **kwargs_orig):
if self.verbose:
print(
"using {} version including {}".format(
Expand All @@ -200,7 +200,7 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250):
# DEBUG
# output = self.get_logits(x)
# correct_batch = y.eq(output.max(dim=1)[1])
output = get_pred(self.get_logits(x))
output = get_pred(self.get_logits(x, **kwargs_orig))
correct_batch = y.eq(output)
robust_flags[start_idx:end_idx] = correct_batch.detach().to(
robust_flags.device
Expand All @@ -215,6 +215,7 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250):

x_adv = x_orig.clone().detach()
startt = time.time()

for attack in self.attacks_to_run:
# item() is super important as pytorch int division uses floor rounding
num_robust = torch.sum(robust_flags).item()
Expand All @@ -238,6 +239,8 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250):
x = x_orig[batch_datapoint_idcs, :].clone().to(self.device)
y = y_orig[batch_datapoint_idcs].clone().to(self.device)

kwargs = mask_kwargs(kwargs_orig, batch_datapoint_idcs)

# DEBUG: set mask for IN-9 dataset experiment
if SET_MASK:
# self.model.set_mask(batch_datapoint_idcs)
Expand All @@ -252,38 +255,34 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250):
# apgd on cross-entropy loss
self.apgd.loss = "ce"
self.apgd.seed = self.get_seed()
adv_curr = self.apgd.perturb(x, y) # cheap=True

adv_curr = self.apgd.perturb(x, y, **kwargs) # cheap=True
elif attack == "apgd-dlr":
# apgd on dlr loss
self.apgd.loss = "dlr"
self.apgd.seed = self.get_seed()
adv_curr = self.apgd.perturb(x, y) # cheap=True
adv_curr = self.apgd.perturb(x, y, **kwargs) # cheap=True

elif attack == "fab":
# fab
self.fab.targeted = False
self.fab.seed = self.get_seed()
adv_curr = self.fab.perturb(x, y)
adv_curr = self.fab.perturb(x, y, **kwargs) # cheap=True

elif attack == "square":
# square
self.square.seed = self.get_seed()
adv_curr = self.square.perturb(x, y)
adv_curr = self.square.perturb(x, y, **kwargs) # cheap=True

elif attack == "apgd-t":
# targeted apgd
self.apgd_targeted.seed = self.get_seed()
adv_curr = self.apgd_targeted.perturb(
x, y
) # cheap=True

adv_curr = self.apgd_targeted.perturb(x, y, **kwargs) # cheap=True
elif attack == "fab-t":
# fab targeted
self.fab.targeted = True
self.fab.n_restarts = 1
self.fab.seed = self.get_seed()
adv_curr = self.fab.perturb(x, y)
adv_curr = self.fab.perturb(x, y, **kwargs) # cheap=True

else:
raise ValueError("Attack not supported")
Expand All @@ -295,7 +294,7 @@ def run_standard_evaluation(self, x_orig, y_orig, bs=250):
# DEBUG
# output = self.get_logits(adv_curr)
# false_batch = ~y.eq(output.max(dim=1)[1]).to(robust_flags.device)
output = get_pred(self.get_logits(adv_curr))
output = get_pred(self.get_logits(adv_curr, **kwargs))
false_batch = ~y.eq(output).to(robust_flags.device)
non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
robust_flags[non_robust_lin_idcs] = False
Expand Down
Loading

0 comments on commit 148019c

Please sign in to comment.