From f4e1cdd9ab413b99d5dbf577e4bd32a75b889720 Mon Sep 17 00:00:00 2001 From: fra31 Date: Tue, 5 Nov 2024 11:43:54 +0000 Subject: [PATCH 1/3] add non-zero gradient versions for Amini2024MeanSparse models --- .../architectures/Meansparse_ra_wrn_70_16.py | 931 ++++++++++++++++ .../architectures/Meansparse_swin_L.py | 995 ++++++++++++++++++ .../architectures/Meansparse_wrn_70_16.py | 402 +++++++ .../architectures/Meansparse_wrn_94_16.py | 388 +++++++ .../architectures/sparsified_model.py | 148 ++- robustbench/model_zoo/cifar10.py | 21 +- robustbench/model_zoo/cifar100.py | 5 + robustbench/model_zoo/imagenet.py | 21 +- robustbench/utils.py | 2 + 9 files changed, 2861 insertions(+), 52 deletions(-) create mode 100644 robustbench/model_zoo/architectures/Meansparse_ra_wrn_70_16.py create mode 100644 robustbench/model_zoo/architectures/Meansparse_swin_L.py create mode 100644 robustbench/model_zoo/architectures/Meansparse_wrn_70_16.py create mode 100644 robustbench/model_zoo/architectures/Meansparse_wrn_94_16.py diff --git a/robustbench/model_zoo/architectures/Meansparse_ra_wrn_70_16.py b/robustbench/model_zoo/architectures/Meansparse_ra_wrn_70_16.py new file mode 100644 index 0000000..927098d --- /dev/null +++ b/robustbench/model_zoo/architectures/Meansparse_ra_wrn_70_16.py @@ -0,0 +1,931 @@ +# Code adapted from https://github.com/wzekai99/DM-Improves-AT +from typing import ( + Tuple, + Optional, + Callable, + Any, + Callable, + Optional, + Tuple, + # List, # TODO: for python<3.9 one needs to use `List[]` instead of `list[]`. +) +import math +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from torchvision.ops.misc import SqueezeExcitation +from collections import OrderedDict +from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation + +from robustbench.model_zoo.architectures.dm_wide_resnet import CIFAR10_MEAN, CIFAR10_STD + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + +INPLACE_ACTIVATIONS = [nn.ReLU] +NORMALIZATIONS = [nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm, nn.LayerNorm] + + +# Custom autograd function for forward and backward modification +# class MeanSparseFunction2D(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, bias, crop, threshold): +# # Save context variables for backward computation if needed +# ctx.save_for_backward(input, bias, crop, threshold) + +# # Forward computation (given in the question) +# if threshold == 0: +# output = input +# else: +# diff = input - bias +# output = torch.where(torch.abs(diff) < crop, bias, input) + +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# # For backward, we want output = input, so we pass grad_output as-is. 
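+#         # (straight-through estimator: the sparsification step is treated as identity in the
+#         # backward pass, which is what provides the "non-zero gradient" behavior this patch targets)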
+# input, bias, crop, threshold = ctx.saved_tensors + +# # Here we assume output = input in backward, so gradient is unchanged +# grad_input = grad_output +# return grad_input, None, None, None # Other inputs (bias, crop, threshold) have no gradients + +# # Define the MeanSparse module with modified backward behavior +# class MeanSparse(nn.Module): +# def __init__(self, in_planes): +# super(MeanSparse, self).__init__() + +# self.register_buffer('running_mean', torch.zeros(in_planes)) +# self.register_buffer('running_var', torch.zeros(in_planes)) + +# self.register_buffer('threshold', torch.tensor(0.0)) +# self.register_buffer('flag_update_statistics', torch.tensor(0)) +# self.register_buffer('batch_num', torch.tensor(0.0)) + +# def forward(self, input): +# if self.flag_update_statistics: +# # Calculate running mean and variance over batch, height, and width dimensions +# self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) +# self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) + +# bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + +# # Use the custom autograd function for forward and backward passes +# output = MeanSparseFunction2D.apply(input, bias, crop, self.threshold) +# return output + +### original +class MeanSparse(nn.Module): + def __init__(self, in_planes): + super(MeanSparse, self).__init__() + + self.register_buffer('running_mean', torch.zeros(in_planes)) + self.register_buffer('running_var', torch.zeros(in_planes)) + + self.register_buffer('threshold', torch.tensor(0.0)) + + self.register_buffer('flag_update_statistics', torch.tensor(0)) + self.register_buffer('batch_num', torch.tensor(0.0)) + + def forward(self, input): + + if self.flag_update_statistics: + self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + + bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) + crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + + diff = input - bias + + if self.threshold == 0: + output = input + else: + output = torch.where(torch.abs(diff) < crop, bias*torch.ones_like(input), input) + + return output + +def normalize_fn(tensor, mean, std): + mean = mean[None, :, None, None] + std = std[None, :, None, None] + return tensor.sub(mean).div(std) + +class NormalizeByChannelMeanStd(nn.Module): + def __init__(self, mean, std): + super(NormalizeByChannelMeanStd, self).__init__() + if not isinstance(mean, torch.Tensor): + mean = torch.tensor(mean) + if not isinstance(std, torch.Tensor): + std = torch.tensor(std) + self.register_buffer("mean", mean) + self.register_buffer("std", std) + + def forward(self, tensor): + return normalize_fn(tensor, self.mean, self.std) + + def extra_repr(self): + return "mean={}, std={}".format(self.mean, self.std) + + +class _Block(nn.Module): + def __init__( + self, + in_planes, + out_planes, + stride, + groups, + activation_fn=nn.ReLU, + se_ratio=None, + se_activation=nn.ReLU, + se_order=1, + ): + super().__init__() + self.batchnorm_0 = nn.BatchNorm2d(in_planes, momentum=0.01) + self.meansparse_0 = MeanSparse(in_planes) + self.relu_0 = activation_fn(inplace=True) + self.conv_0 = nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + groups=groups, + padding=0, + bias=False, + ) + 
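+        # NOTE: conv_0 is built with padding=0; forward() pads its input manually via F.pad,
+        # symmetric (1, 1, 1, 1) for stride 1 and asymmetric (0, 1, 0, 1) for stride 2.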
self.batchnorm_1 = nn.BatchNorm2d(out_planes, momentum=0.01) + self.meansparse_1 = MeanSparse(out_planes) + self.relu_1 = activation_fn(inplace=True) + self.conv_1 = nn.Conv2d( + out_planes, + out_planes, + kernel_size=3, + stride=1, + groups=groups, + padding=1, + bias=False, + ) + self.has_shortcut = in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d( + in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False, + ) + else: + self.shortcut = None + self._stride = stride + + self.se = None + if se_ratio: + assert se_activation is not None + width_se_out = int(round(se_ratio * out_planes)) + self.se = SqueezeExcitation( + input_channels=out_planes, + squeeze_channels=width_se_out, + activation=se_activation, + ) + self.se_order = se_order + self.meansparse_2 = MeanSparse(out_planes) + + def forward(self, x): + if self.has_shortcut: + x = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + else: + out = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + v = x if self.has_shortcut else out + if self._stride == 1: + v = F.pad(v, (1, 1, 1, 1)) + elif self._stride == 2: + v = F.pad(v, (0, 1, 0, 1)) + else: + raise ValueError("Unsupported `stride`.") + out = self.conv_0(v) + + if self.se and self.se_order == 1: + out = self.se(out) + + out = self.relu_1(self.meansparse_1(self.batchnorm_1(out))) + out = self.conv_1(out) + + if self.se and self.se_order == 2: + out = self.se(out) + + out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) + out = self.meansparse_2(out) + return out + + +class _BlockGroup(nn.Module): + def __init__( + self, + num_blocks, + in_planes, + out_planes, + stride, + groups, + activation_fn=nn.ReLU, + se_ratio=None, + se_activation=nn.ReLU, + se_order=1, + ): + super().__init__() + block = [] + for i in range(num_blocks): + block.append( + _Block( + i == 0 and in_planes or out_planes, + out_planes, + i == 0 and stride or 1, + groups=groups, + activation_fn=activation_fn, + se_ratio=se_ratio, + se_activation=se_activation, + se_order=se_order, + ) + ) + self.block = nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class NormalizedWideResNet(nn.Module): + def __init__( + self, + mean: Tuple[float], + std: Tuple[float], + stem_width: int, + depth: Tuple[int], + stage_width: Tuple[int], + groups: Tuple[int], + activation_fn: nn.Module, + se_ratio: Optional[float], + se_activation: Optional[Callable[..., nn.Module]], + se_order: Optional[int], + num_classes: int = 10, + padding: int = 0, + num_input_channels: int = 3, + ): + super().__init__() + self.mean = torch.tensor(mean).view(num_input_channels, 1, 1) + self.std = torch.tensor(std).view(num_input_channels, 1, 1) + self.mean_cuda = None + self.std_cuda = None + self.padding = padding + num_channels = [stem_width, *stage_width] + self.init_conv = nn.Conv2d( + num_input_channels, + num_channels[0], + kernel_size=3, + stride=1, + padding=1, + bias=False, + ) + self.layer = nn.Sequential( + _BlockGroup( + depth[0], + num_channels[0], + num_channels[1], + 1, + groups=groups[0], + activation_fn=activation_fn, + se_ratio=se_ratio, + se_activation=se_activation, + se_order=se_order, + ), + _BlockGroup( + depth[1], + num_channels[1], + num_channels[2], + 2, + groups=groups[1], + activation_fn=activation_fn, + se_ratio=se_ratio, + se_activation=se_activation, + se_order=se_order, + ), + _BlockGroup( + depth[2], + num_channels[2], + num_channels[3], + 2, + groups=groups[2], + activation_fn=activation_fn, + se_ratio=se_ratio, + 
se_activation=se_activation, + se_order=se_order, + ), + ) + self.batchnorm = nn.BatchNorm2d(num_channels[3], momentum=0.01) + self.meansparse_end = MeanSparse(num_channels[3]) + self.relu = activation_fn(inplace=True) + self.logits = nn.Linear(num_channels[3], num_classes) + self.num_channels = num_channels[3] + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2.0 / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + if x.is_cuda: + #if self.mean_cuda is None + self.mean_cuda = self.mean.to(x.device) # TODO: improve this. + self.std_cuda = self.std.to(x.device) + out = (x - self.mean_cuda) / self.std_cuda + else: + out = (x - self.mean) / self.std + + out = self.init_conv(out) + out = self.layer(out) + out = self.relu(self.meansparse_end(self.batchnorm(out))) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.num_channels) + return self.logits(out) + + +class NormActivationConv(torch.nn.Sequential): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d, + ) -> None: + if padding is None: + padding = (kernel_size - 1) // 2 * dilation + if bias is None: + bias = norm_layer is None + + layers = list() + + if norm_layer is not None: + layers.append(norm_layer(in_channels)) + + if activation_layer is not None: + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(**params)) + + layers.append( + conv_layer( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + ) + super().__init__(*layers) + self.out_channels = out_channels + + +class NormActivationConv2d(NormActivationConv): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 1, + padding: Optional[int] = None, + groups: int = 1, + norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, + activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, + dilation: int = 1, + inplace: Optional[bool] = True, + bias: Optional[bool] = None, + ) -> None: + super().__init__( + in_channels, + out_channels, + kernel_size, + stride, + padding, + groups, + norm_layer, + activation_layer, + dilation, + inplace, + bias, + torch.nn.Conv2d, + ) + + +class BottleneckTransform(nn.Sequential): + "Transformation in a Bottleneck: 1x1, kxk (k=3, 5, 7, ...) 
[+SE], 1x1" + "Supported archs: [preact] [norm func+num] [act func+num] [conv kernel]" + + def __init__( + self, + width_in: int, + width_out: int, + kernel: int, + stride: int, + dilation: int, + norm_layer: list[Callable[..., nn.Module]], + activation_layer: list[Callable[..., nn.Module]], + group_width: int, + bottleneck_multiplier: float, + se_ratio: Optional[float], + se_activation: Optional[Callable[..., nn.Module]], + ConvBlock: Callable[..., nn.Module], + ): + # compute transform params + w_b = int( + round(width_out * bottleneck_multiplier) + ) # bottleneck_multiplier > 1 for inverted bottleneck + g = w_b // group_width + assert len(norm_layer) == 3 + assert len(activation_layer) == 3 + assert g > 0, f"Group convolution groups {g} should be greater than 0." + assert ( + w_b % g == 0 + ), f"Convolution input channels {w_b} is not divisible by {g} groups." + + layers: OrderedDict[str, nn.Module] = OrderedDict() + layers["a"] = ConvBlock( + width_in, + w_b, + kernel_size=1, + stride=1, + norm_layer=norm_layer[0], + activation_layer=activation_layer[0], + inplace=True if activation_layer[0] in INPLACE_ACTIVATIONS else None, + ) + + layers["b"] = ConvBlock( + w_b, + w_b, + kernel, + stride=stride, + groups=g, + dilation=dilation, + norm_layer=norm_layer[1], + activation_layer=activation_layer[1], + inplace=True if activation_layer[1] in INPLACE_ACTIVATIONS else None, + ) + + if se_ratio: + assert se_activation is not None + width_se_out = int(round(se_ratio * width_in)) + layers["se"] = SqueezeExcitation( + input_channels=w_b, + squeeze_channels=width_se_out, + activation=se_activation, + ) + if ConvBlock == Conv2dNormActivation: + layers["c"] = ConvBlock( + w_b, + width_out, + kernel_size=1, + stride=1, + norm_layer=norm_layer[2], + activation_layer=None, + ) + else: + layers["c"] = ConvBlock( + w_b, + width_out, + kernel_size=1, + stride=1, + norm_layer=norm_layer[2], + activation_layer=activation_layer[2], + inplace=True if activation_layer[2] in INPLACE_ACTIVATIONS else None, + ) + + super().__init__(layers) + + +class BottleneckBlock(nn.Module): + """Bottleneck block x + F(x), where F = bottleneck transform""" + + def __init__( + self, + width_in: int, + width_out: int, + kernel: int, + stride: int, + dilation: int, + norm_layer: list[Callable[..., nn.Module]], + activation_layer: list[Callable[..., nn.Module]], + group_width: int, + bottleneck_multiplier: float, + se_ratio: Optional[float], + se_activation: Optional[Callable[..., nn.Module]], + ConvBlock: Callable[..., nn.Module], + downsample_norm: Callable[..., nn.Module], + ) -> None: + super().__init__() + + # projection on skip connection if shape changes + self.proj = None + should_proj = (width_in != width_out) or (stride != 1) + if should_proj: + if ConvBlock == Conv2dNormActivation: + self.proj = ConvBlock( + width_in, + width_out, + kernel_size=1, + stride=stride, + norm_layer=downsample_norm, + activation_layer=None, + ) + elif ConvBlock == NormActivationConv2d: + self.proj = ConvBlock( + width_in, + width_out, + kernel_size=1, + stride=stride, + norm_layer=None, + activation_layer=None, + bias=False, + ) + + self.F = BottleneckTransform( + width_in, + width_out, + kernel, + stride, + dilation, + norm_layer, + activation_layer, + group_width, + bottleneck_multiplier, + se_ratio, + se_activation, + ConvBlock, + ) + + if ConvBlock == Conv2dNormActivation: + if activation_layer[2] is not None: + if activation_layer[2] in INPLACE_ACTIVATIONS: + self.last_activation = activation_layer[2](inplace=True) + else: + 
self.last_activation = activation_layer[2]() + else: + self.last_activation = None + + def forward(self, x: Tensor) -> Tensor: + if self.proj is not None: + x = self.proj(x) + self.F(x) + else: + x = x + self.F(x) + + if self.last_activation is not None: + return self.last_activation(x) + else: + return x + + +class Stage(nn.Sequential): + """Stage is a sequence of blocks with the same output shape. Downsampling block is the first in each stage""" + + """Options: stage numbers, stage depth, dense connection""" + + def __init__( + self, + width_in: int, + width_out: int, + kernel: int, + stride: int, + dilation: int, + norm_layer: list[Callable[..., nn.Module]], + activation_layer: list[Callable[..., nn.Module]], + group_width: int, + bottleneck_multiplier: float, + se_ratio: Optional[float], + se_activation: Optional[Callable[..., nn.Module]], + ConvBlock: Callable[..., nn.Module], + downsample_norm: Callable[..., nn.Module], + depth: int, + dense_ratio: Optional[float], + block_constructor: Callable[..., nn.Module] = BottleneckBlock, + stage_index: int = 0, + ): + super().__init__() + self.dense_ratio = dense_ratio + for i in range(depth): + block = block_constructor( + width_in if i == 0 else width_out, + width_out, + kernel, + stride if i == 0 else 1, + dilation, + norm_layer, + activation_layer, + group_width, + bottleneck_multiplier, + se_ratio, + se_activation, + ConvBlock, + downsample_norm, + ) + + self.add_module(f"stage{stage_index}-block{i}", block) + + def forward(self, x: Tensor) -> Tensor: + if self.dense_ratio: + assert self.dense_ratio > 0 + features = list([x]) + for i, module in enumerate(self): + input = features[-1] + if i > 2: + for j in range(self.dense_ratio): + if j + 4 > len(features): + break + input = input + features[-3 - j] + x = module(input) + features.append(x) + + # output of each stage is also densely connected + x = features[-1] + for k in range(self.dense_ratio): + if k + 4 > len(features): + break + x = x + features[-3 - k] + else: + for module in self: + x = module(x) + return x + + +class Stem(nn.Module): + """Stem for ImageNet: kxk, BN, ReLU[, MaxPool]""" + + def __init__( + self, + width_in: int, + width_out: int, + kernel_size: int, + norm_layer: Callable[..., nn.Module], + activation_layer: Callable[..., nn.Module], + downsample_factor: int, + patch_size: Optional[int], + ) -> None: + super().__init__() + + assert downsample_factor % 2 == 0 and downsample_factor >= 2 + layers: OrderedDict[str, nn.Module] = OrderedDict() + + stride = 2 + if patch_size: + kernel_size = patch_size + stride = patch_size + + layers["stem"] = Conv2dNormActivation( + width_in, + width_out, + kernel_size=kernel_size, + stride=stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + + if not patch_size and downsample_factor // 2 > 1: + layers["stem_downsample"] = nn.MaxPool2d( + kernel_size=3, stride=downsample_factor // 2, padding=1 + ) + + self.stem = nn.Sequential(layers) + + def forward(self, x: Tensor) -> Tensor: + return self.stem(x) + + +class ConfigurableModel(nn.Module): + def __init__( + self, + stage_widths: list[int], # output width of each stage + kernel: int, # kernel for non-pointwise conv + strides: list[int], # stride in each stage + dilation: int, # dilation for non-pointwise conv + norm_layer: list[ + Callable[..., nn.Module] + ], # norm layer in each block, length 3 for bottleneck + activation_layer: list[ + Callable[..., nn.Module] + ], # activation layer in each block, length 3 for bottleneck + group_widths: list[ + int + ], # group 
conv width in each stage, groups = width_out * bottleneck_multiplier // group_width + bottleneck_multipliers: list[ + float + ], # bottleneck_multiplier > 1 for inverted bottleneck + downsample_norm: Callable[ + ..., nn.Module + ], # norm layer in downsampling shortcut + depths: list[int], # depth in each stage + dense_ratio: Optional[float], # dense connection ratio + stem_type: Callable[..., nn.Module], # stem stage + stem_width: int, # stem stage output width + stem_kernel: int, # stem stage kernel size + stem_downsample_factor: int, # downscale factor in the stem stage, if > 2, a maxpool layer is added + stem_patch_size: Optional[int], # patchify stem patch size + block_constructor: Callable[ + ..., nn.Module + ] = BottleneckBlock, # block type in body stage + ConvBlock: Callable[ + ..., nn.Module + ] = Conv2dNormActivation, # block with different "conv-norm-act" order + se_ratio: Optional[float] = None, # squeeze and excitation (SE) ratio + se_activation: Optional[ + Callable[..., nn.Module] + ] = None, # activation layer in SE block + weight_init_type: str = "resnet", # initialization type + num_classes: int = 1000, # num of classification classes + ) -> None: + super().__init__() + + num_stages = len(stage_widths) + assert len(strides) == num_stages + assert len(bottleneck_multipliers) == num_stages + assert len(group_widths) == num_stages + assert len(norm_layer) == len(activation_layer) + assert ( + sum([i % 8 for i in stage_widths]) == 0 + ), f"Stage width {stage_widths} non-divisible by 8" + + # stem + self.stem = stem_type( + width_in=3, + width_out=stem_width, + kernel_size=stem_kernel, + norm_layer=nn.BatchNorm2d, + activation_layer=nn.ReLU, + downsample_factor=stem_downsample_factor, + patch_size=stem_patch_size, + ) + + # stages + current_width = stem_width + stages = list() + for i, ( + width_out, + stride, + group_width, + bottleneck_multiplier, + depth, + ) in enumerate( + zip(stage_widths, strides, group_widths, bottleneck_multipliers, depths) + ): + stages.append( + ( + f"stage{i + 1}", + Stage( + current_width, + width_out, + kernel, + stride, + dilation, + norm_layer, + activation_layer, + group_width, + bottleneck_multiplier, + se_ratio, + se_activation, + ConvBlock, + downsample_norm, + depth, + dense_ratio, + block_constructor, + stage_index=i + 1, + ), + ) + ) + + current_width = width_out + + self.stages = nn.Sequential(OrderedDict(stages)) + + # classification head + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(current_width, num_classes) + + # initialization + if weight_init_type == "resnet": + for m in self.modules(): + if isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + nn.init.normal_(m.weight, mean=0.0, std=math.sqrt(2.0 / fan_out)) + # nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif m in NORMALIZATIONS: + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, mean=0.0, std=0.01) + nn.init.zeros_(m.bias) + else: + raise NotImplementedError + + def forward(self, x: Tensor) -> Tensor: + x = self.stem(x) + x = self.stages(x) + + x = self.avgpool(x) + x = x.flatten(start_dim=1) + x = self.fc(x) + + return x + + +class NormalizedConfigurableModel(ConfigurableModel): + def __init__(self, mean: list[float], std: list[float], **kwargs: Any): + super().__init__(**kwargs) + + assert len(mean) == len(std) + self.normalization = NormalizeByChannelMeanStd(mean=mean, std=std) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + x = self.normalization(x) + + x = self.stem(x) + x = self.stages(x) + + x = self.avgpool(x) + x = x.flatten(start_dim=1) + x = self.fc(x) + + return x + + +def get_model(model_name): + if model_name == "ra_wrn70_16": + model = NormalizedWideResNet( + CIFAR10_MEAN, + CIFAR10_STD, + stem_width=96, + depth=[30, 31, 10], + stage_width=[216, 432, 864], + groups=[1, 1, 1], + activation_fn=torch.nn.SiLU, + se_ratio=0.25, + se_activation=torch.nn.ReLU, + se_order=2, + num_classes=10, + ) + elif model_name == "ra_wrn101_2": + model = NormalizedConfigurableModel( + mean=IMAGENET_MEAN, + std=IMAGENET_STD, + stage_widths=[512, 1024, 2016, 4032], + kernel=3, + strides=[2, 2, 2, 2], + dilation=1, + norm_layer=[nn.Identity, nn.BatchNorm2d, nn.BatchNorm2d], + activation_layer=[nn.SiLU] * 3, + group_widths=[64, 128, 252, 504], + bottleneck_multipliers=[0.25] * 4, + downsample_norm=nn.BatchNorm2d, + depths=[7, 11, 18, 1], + dense_ratio=None, + stem_type=Stem, + stem_width=96, + stem_kernel=7, + stem_downsample_factor=2, + stem_patch_size=None, + block_constructor=BottleneckBlock, + ConvBlock=Conv2dNormActivation, + se_ratio=0.25, + se_activation=nn.ReLU, + ) + else: + raise ValueError(f"Unknown model name: {model_name}.") + + return model + + +if __name__ == "__main__": + model = get_model("ra_wrn70_16") + model.cuda() + x = torch.rand([10, 3, 32, 32]) + with torch.no_grad(): + print(model(x.cuda()).shape) + + model = get_model("ra_wrn101_2") + #print(model.state_dict().keys()) + model.cuda() + x = torch.rand([10, 3, 224, 224]) + with torch.no_grad(): + print(model(x.cuda()).shape) \ No newline at end of file diff --git a/robustbench/model_zoo/architectures/Meansparse_swin_L.py b/robustbench/model_zoo/architectures/Meansparse_swin_L.py new file mode 100644 index 0000000..87ad338 --- /dev/null +++ b/robustbench/model_zoo/architectures/Meansparse_swin_L.py @@ -0,0 +1,995 @@ +""" Swin Transformer +A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` + - https://arxiv.org/pdf/2103.14030 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from + - https://github.com/microsoft/Cream/tree/main/AutoFormerV2 + +Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import logging +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import PatchEmbed, DropPath, ClassifierHead, to_2tuple, to_ntuple, trunc_normal_, \ + _assert, use_fused_attn, resize_rel_pos_bias_table, resample_patch_embed, ndgrid +from timm.models._builder import build_model_with_cfg +from timm.models._features_fx import register_notrace_function +from timm.models._manipulate import checkpoint_seq, named_apply +from timm.models._registry import generate_default_cfgs, register_model, register_model_deprecations +from timm.models.vision_transformer import get_init_weights_vit +from collections import OrderedDict +from torch import Tensor + + +# Custom autograd function to modify backward pass +# class MeanSparseFunction(torch.autograd.Function): +# 
@staticmethod +# def forward(ctx, input, bias, crop, threshold): +# # Save context variables for backward computation if needed +# ctx.save_for_backward(input, bias, crop, threshold) + +# # Forward computation (as given in the question) +# if threshold == 0: +# output = input +# else: +# diff = input - bias +# output = torch.where(torch.abs(diff) < crop, bias * torch.ones_like(input), input) + +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# # For backward, we want output = input, so we pass grad_output as-is. +# input, bias, crop, threshold = ctx.saved_tensors + +# # Here we assume output = input in backward, so gradient is unchanged +# grad_input = grad_output +# return grad_input, None, None, None # Other inputs (bias, crop, threshold) have no gradients + +# # Define the MeanSparse module with the custom backward behavior +# class MeanSparse(nn.Module): +# def __init__(self, in_planes): +# super(MeanSparse, self).__init__() + +# self.register_buffer('running_mean', torch.zeros(in_planes)) +# self.register_buffer('running_var', torch.zeros(in_planes)) + +# self.register_buffer('threshold', torch.tensor(0.0)) +# self.register_buffer('flag_update_statistics', torch.tensor(0)) +# self.register_buffer('batch_num', torch.tensor(0.0)) + +# def forward(self, input): +# if self.flag_update_statistics: +# self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 1)) / self.batch_num) +# self.running_var += (torch.var(input.detach().clone(), dim=(0, 1)) / self.batch_num) + +# bias = self.running_mean.view(1, 1, self.running_mean.shape[0]) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, 1, self.running_var.shape[0]) + +# # Use the custom autograd function for forward and backward passes +# output = MeanSparseFunction.apply(input, bias, crop, self.threshold) +# return output + + +### original +class MeanSparse(nn.Module): + def __init__(self, in_planes): + super(MeanSparse, self).__init__() + + self.register_buffer('running_mean', torch.zeros(in_planes)) + self.register_buffer('running_var', torch.zeros(in_planes)) + + self.register_buffer('threshold', torch.tensor(0.0)) + + self.register_buffer('flag_update_statistics', torch.tensor(0)) + self.register_buffer('batch_num', torch.tensor(0.0)) + + def forward(self, input): + + if self.flag_update_statistics: + self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 1))/self.batch_num) + self.running_var += (torch.var(input.detach().clone(), dim=(0, 1))/self.batch_num) + + bias = self.running_mean.view(1, 1, self.running_mean.shape[0]) + crop = self.threshold * torch.sqrt(self.running_var).view(1, 1, self.running_var.shape[0]) + + diff = input - bias + + if self.threshold == 0: + output = input + else: + output = torch.where(torch.abs(diff) < crop, bias*torch.ones_like(input), input) + + return output + + +class Mlp(nn.Module): + """ MLP as used in Vision Transformer, MLP-Mixer and related networks + """ + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0., + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) 
if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + self.meansparse1 = MeanSparse(hidden_features) + self.meansparse2 = MeanSparse(in_features) + def forward(self, x): + x = self.fc1(x) + x = self.meansparse1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.meansparse2(x) + x = self.drop2(x) + return x + + +class ImageNormalizer(nn.Module): + + def __init__(self, mean: Tuple[float, float, float], + std: Tuple[float, float, float]) -> None: + super(ImageNormalizer, self).__init__() + + self.register_buffer('mean', torch.as_tensor(mean).view(1, 3, 1, 1)) + self.register_buffer('std', torch.as_tensor(std).view(1, 3, 1, 1)) + + def forward(self, input: Tensor) -> Tensor: + return (input - self.mean) / self.std + + def __repr__(self): + return f'ImageNormalizer(mean={self.mean.squeeze()}, std={self.std.squeeze()})' # type: ignore + +__all__ = ['SwinTransformer'] # model_registry will add each entrypoint fn to this + +_logger = logging.getLogger(__name__) + +_int_or_tuple_2_t = Union[int, Tuple[int, int]] + + +def window_partition( + x: torch.Tensor, + window_size: Tuple[int, int], +) -> torch.Tensor: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: Tuple[int, int], H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def get_relative_position_index(win_h: int, win_w: int): + # get pair-wise relative position index for each token inside the window + coords = torch.stack(ndgrid(torch.arange(win_h), torch.arange(win_w))) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_h - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_w - 1 + relative_coords[:, :, 0] *= 2 * win_w - 1 + return relative_coords.sum(-1) # Wh*Ww, Wh*Ww + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports shifted and non-shifted windows. + """ + fused_attn: torch.jit.Final[bool] + + def __init__( + self, + dim: int, + num_heads: int, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + qkv_bias: bool = True, + attn_drop: float = 0., + proj_drop: float = 0., + ): + """ + Args: + dim: Number of input channels. + num_heads: Number of attention heads. 
+ head_dim: Number of channels per head (dim // num_heads if not set) + window_size: The height and width of the window. + qkv_bias: If True, add a learnable bias to query, key, value. + attn_drop: Dropout ratio of attention weight. + proj_drop: Dropout ratio of output. + """ + super().__init__() + self.dim = dim + self.window_size = to_2tuple(window_size) # Wh, Ww + win_h, win_w = self.window_size + self.window_area = win_h * win_w + self.num_heads = num_heads + head_dim = head_dim or dim // num_heads + attn_dim = head_dim * num_heads + self.scale = head_dim ** -0.5 + self.fused_attn = use_fused_attn(experimental=True) # NOTE not tested for prime-time yet + + # define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH + self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * win_h - 1) * (2 * win_w - 1), num_heads)) + + # get pair-wise relative position index for each token inside the window + self.register_buffer("relative_position_index", get_relative_position_index(win_h, win_w), persistent=False) + + self.qkv = nn.Linear(dim, attn_dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(attn_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def _get_rel_pos_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view(self.window_area, self.window_area, -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + return relative_position_bias.unsqueeze(0) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + if self.fused_attn: + attn_mask = self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + mask = mask.view(1, num_win, 1, N, N).expand(B_ // num_win, -1, self.num_heads, -1, -1) + attn_mask = attn_mask + mask.reshape(-1, self.num_heads, N, N) + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn + self._get_rel_pos_bias() + if mask is not None: + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B_, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """ Swin Transformer Block. + """ + + def __init__( + self, + dim: int, + input_resolution: _int_or_tuple_2_t, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + shift_size: int = 0, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + window_size: Window size. + num_heads: Number of attention heads. 
+ head_dim: Enforce the number of channels per head + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] + self.mlp_ratio = mlp_ratio + + self.norm1 = norm_layer(dim) + self.meansparse1 = MeanSparse(dim) + self.attn = WindowAttention( + dim, + num_heads=num_heads, + head_dim=head_dim, + window_size=to_2tuple(self.window_size), + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=proj_drop, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.meansparse2 = MeanSparse(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.input_resolution + H = math.ceil(H / self.window_size[0]) * self.window_size[0] + W = math.ceil(W / self.window_size[1]) * self.window_size[1] + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask, persistent=False) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _attn(self, x): + B, H, W, C = x.shape + + # cyclic shift + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + else: + shifted_x = x + + # pad for resolution not divisible by window size + pad_h = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_w = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + shifted_x = torch.nn.functional.pad(shifted_x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, 
window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + shifted_x = shifted_x[:, :H, :W, :].contiguous() + + # reverse cyclic shift + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x + return x + + def forward(self, x): + B, H, W, C = x.shape + x = x + self.drop_path1(self._attn(self.norm1(x))) + x = x.reshape(B, -1, C) + x = self.meansparse1(x) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = self.meansparse2(x) + x = x.reshape(B, H, W, C) + return x + + +class PatchMerging(nn.Module): + """ Patch Merging Layer. + """ + + def __init__( + self, + dim: int, + out_dim: Optional[int] = None, + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels (or 2 * dim if None) + norm_layer: Normalization layer. + """ + super().__init__() + self.dim = dim + self.out_dim = out_dim or 2 * dim + self.norm = norm_layer(4 * dim) + self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False) + + def forward(self, x): + B, H, W, C = x.shape + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3) + x = self.norm(x) + x = self.reduction(x) + return x + + +class SwinTransformerStage(nn.Module): + """ A basic Swin Transformer layer for one stage. + """ + + def __init__( + self, + dim: int, + out_dim: int, + input_resolution: Tuple[int, int], + depth: int, + downsample: bool = True, + num_heads: int = 4, + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + proj_drop: float = 0., + attn_drop: float = 0., + drop_path: Union[List[float], float] = 0., + norm_layer: Callable = nn.LayerNorm, + ): + """ + Args: + dim: Number of input channels. + out_dim: Number of output channels. + input_resolution: Input resolution. + depth: Number of blocks. + downsample: Downsample layer at the end of the layer. + num_heads: Number of attention heads. + head_dim: Channels per head (dim // num_heads if not set) + window_size: Local window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + proj_drop: Projection dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. 
+ """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + self.depth = depth + self.grad_checkpointing = False + window_size = to_2tuple(window_size) + shift_size = tuple([w // 2 for w in window_size]) + + # patch merging layer + if downsample: + self.downsample = PatchMerging( + dim=dim, + out_dim=out_dim, + norm_layer=norm_layer, + ) + else: + assert dim == out_dim + self.downsample = nn.Identity() + + # build blocks + self.blocks = nn.Sequential(*[ + SwinTransformerBlock( + dim=out_dim, + input_resolution=self.output_resolution, + num_heads=num_heads, + head_dim=head_dim, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else shift_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_drop=proj_drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth)]) + + def forward(self, x): + x = self.downsample(x) + + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class SwinTransformer(nn.Module): + """ Swin Transformer + + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + """ + + def __init__( + self, + img_size: _int_or_tuple_2_t = 224, + patch_size: int = 4, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 96, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + head_dim: Optional[int] = None, + window_size: _int_or_tuple_2_t = 7, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + proj_drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + embed_layer: Callable = PatchEmbed, + norm_layer: Union[str, Callable] = nn.LayerNorm, + weight_init: str = '', + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer layer. + num_heads: Number of attention heads in different layers. + head_dim: Dimension of self-attention heads. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Dropout rate. + attn_drop_rate (float): Attention dropout rate. + drop_path_rate (float): Stochastic depth rate. + embed_layer: Patch embedding layer. + norm_layer (nn.Module): Normalization layer. 
+ """ + super().__init__() + assert global_pool in ('', 'avg') + self.num_classes = num_classes + self.global_pool = global_pool + self.output_fmt = 'NHWC' + + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.feature_info = [] + + if not isinstance(embed_dim, (tuple, list)): + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + + # split image into non-overlapping patches + self.patch_embed = embed_layer( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) + self.patch_grid = self.patch_embed.grid_size + + # build layers + head_dim = to_ntuple(self.num_layers)(head_dim) + if not isinstance(window_size, (list, tuple)): + window_size = to_ntuple(self.num_layers)(window_size) + elif len(window_size) == 2: + window_size = (window_size,) * self.num_layers + assert len(window_size) == self.num_layers + mlp_ratio = to_ntuple(self.num_layers)(mlp_ratio) + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + layers = [] + in_dim = embed_dim[0] + scale = 1 + for i in range(self.num_layers): + out_dim = embed_dim[i] + layers += [SwinTransformerStage( + dim=in_dim, + out_dim=out_dim, + input_resolution=( + self.patch_grid[0] // scale, + self.patch_grid[1] // scale + ), + depth=depths[i], + downsample=i > 0, + num_heads=num_heads[i], + head_dim=head_dim[i], + window_size=window_size[i], + mlp_ratio=mlp_ratio[i], + qkv_bias=qkv_bias, + proj_drop=proj_drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + )] + in_dim = out_dim + if i > 0: + scale *= 2 + self.feature_info += [dict(num_chs=out_dim, reduction=4 * scale, module=f'layers.{i}')] + self.layers = nn.Sequential(*layers) + + self.norm = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=drop_rate, + input_fmt=self.output_fmt, + ) + if weight_init != 'skip': + self.init_weights(weight_init) + + @torch.jit.ignore + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'moco', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
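+        # recursively apply timm's ViT-style weight initializer (variant selected by `mode`) to every submodule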
+ named_apply(get_init_weights_vit(mode, head_bias=head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + nwd = set() + for n, _ in self.named_parameters(): + if 'relative_position_bias_table' in n: + nwd.add(n) + return nwd + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, pool_type=global_pool) + + def forward_features(self, x): + x = self.patch_embed(x) + x = self.layers(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + old_weights = True + if 'head.fc.weight' in state_dict: + old_weights = False + import re + out_dict = {} + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'attn_mask')]): + continue # skip buffers that should not be persistent + + if 'patch_embed.proj.weight' in k: + _, _, H, W = model.patch_embed.proj.weight.shape + if v.shape[-2] != H or v.shape[-1] != W: + v = resample_patch_embed( + v, + (H, W), + interpolation='bicubic', + antialias=True, + verbose=True, + ) + + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + + if old_weights: + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + + out_dict[k] = v + return out_dict + + +def _create_swin_transformer(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + SwinTransformer, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'license': 'mit', **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # 'swin_small_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22kto1k_finetune.pth', ), + # 
'swin_base_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22kto1k.pth',), + # 'swin_base_patch4_window12_384.ms_in22k_ft_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22kto1k.pth', + # input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + 'swin_large_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22kto1k.pth',), + # 'swin_large_patch4_window12_384.ms_in22k_ft_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth', + # input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + # 'swin_tiny_patch4_window7_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth',), + # 'swin_small_patch4_window7_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_small_patch4_window7_224.pth',), + # 'swin_base_patch4_window7_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224.pth',), + # 'swin_base_patch4_window12_384.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384.pth', + # input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + + # # tiny 22k pretrain is worse than 1k, so moved after (untagged priority is based on order) + # 'swin_tiny_patch4_window7_224.ms_in22k_ft_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22kto1k_finetune.pth',), + + # 'swin_tiny_patch4_window7_224.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_tiny_patch4_window7_224_22k.pth', + # num_classes=21841), + # 'swin_small_patch4_window7_224.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.8/swin_small_patch4_window7_224_22k.pth', + # num_classes=21841), + # 'swin_base_patch4_window7_224.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth', + # num_classes=21841), + # 'swin_base_patch4_window12_384.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth', + # input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + # 'swin_large_patch4_window7_224.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window7_224_22k.pth', + # num_classes=21841), + # 'swin_large_patch4_window12_384.ms_in22k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth', + # input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, num_classes=21841), + + # 'swin_s3_tiny_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_t-1d53f6a8.pth'), + # 'swin_s3_small_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_s-3bb4c69d.pth'), + # 'swin_s3_base_224.ms_in1k': _cfg( + # hf_hub_id='timm/', + # url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/s3_b-a1e95db4.pth'), +}) + + +# @register_model +# def swin_tiny_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-T @ 224x224, trained ImageNet-1k +# """ +# model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) +# return _create_swin_transformer( +# 'swin_tiny_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_small_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-S @ 224x224 +# """ +# model_args = dict(patch_size=4, window_size=7, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24)) +# return _create_swin_transformer( +# 'swin_small_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_base_patch4_window7_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-B @ 224x224 +# """ +# model_args = dict(patch_size=4, window_size=7, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) +# return _create_swin_transformer( +# 'swin_base_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_base_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-B @ 384x384 +# """ +# model_args = dict(patch_size=4, window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32)) +# return _create_swin_transformer( +# 'swin_base_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + +@register_model +def swin_large_patch4_window7_224_with_MeanSparse(pretrained=False, **kwargs) -> SwinTransformer: + """ Swin-L @ 224x224 + """ + model_args = dict(patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)) + model = _create_swin_transformer( + 'swin_large_patch4_window7_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return nn.Sequential(OrderedDict([('normalize', ImageNormalizer(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))), ('model', model)])) + + +# @register_model +# def swin_large_patch4_window12_384(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-L @ 384x384 +# """ +# model_args = dict(patch_size=4, window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48)) +# return _create_swin_transformer( +# 'swin_large_patch4_window12_384', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_s3_tiny_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-S3-T @ 224x224, https://arxiv.org/abs/2111.14725 +# """ +# model_args = dict( +# patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24)) +# return _create_swin_transformer('swin_s3_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_s3_small_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-S3-S @ 224x224, https://arxiv.org/abs/2111.14725 +# """ +# model_args = dict( +# patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), 
num_heads=(3, 6, 12, 24)) +# return _create_swin_transformer('swin_s3_small_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# @register_model +# def swin_s3_base_224(pretrained=False, **kwargs) -> SwinTransformer: +# """ Swin-S3-B @ 224x224, https://arxiv.org/abs/2111.14725 +# """ +# model_args = dict( +# patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), num_heads=(3, 6, 12, 24)) +# return _create_swin_transformer('swin_s3_base_224', pretrained=pretrained, **dict(model_args, **kwargs)) + + +# register_model_deprecations(__name__, { +# 'swin_base_patch4_window7_224_in22k': 'swin_base_patch4_window7_224.ms_in22k', +# 'swin_base_patch4_window12_384_in22k': 'swin_base_patch4_window12_384.ms_in22k', +# 'swin_large_patch4_window7_224_in22k': 'swin_large_patch4_window7_224.ms_in22k', +# 'swin_large_patch4_window12_384_in22k': 'swin_large_patch4_window12_384.ms_in22k', +# }) \ No newline at end of file diff --git a/robustbench/model_zoo/architectures/Meansparse_wrn_70_16.py b/robustbench/model_zoo/architectures/Meansparse_wrn_70_16.py new file mode 100644 index 0000000..836da4e --- /dev/null +++ b/robustbench/model_zoo/architectures/Meansparse_wrn_70_16.py @@ -0,0 +1,402 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WideResNet implementation in PyTorch. From: +https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py +""" + +from typing import Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) +CIFAR10_STD = (0.2471, 0.2435, 0.2616) +CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) +CIFAR100_STD = (0.2673, 0.2564, 0.2762) + + +# Custom autograd function for forward and backward modification +# class MeanSparseFunction2D(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, bias, crop, threshold): +# # Save context variables for backward computation if needed +# ctx.save_for_backward(input, bias, crop, threshold) + +# # Forward computation (given in the question) +# if threshold == 0: +# output = input +# else: +# diff = input - bias +# output = torch.where(torch.abs(diff) < crop, bias, #* torch.ones_like(input), +# input) + +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# # For backward, we want output = input, so we pass grad_output as-is. 
+# input, bias, crop, threshold = ctx.saved_tensors + +# # Here we assume output = input in backward, so gradient is unchanged +# grad_input = grad_output +# return grad_input, None, None, None # Other inputs (bias, crop, threshold) have no gradients + +# # Define the MeanSparse module with modified backward behavior +# class MeanSparse(nn.Module): +# def __init__(self, in_planes): +# super(MeanSparse, self).__init__() + +# self.register_buffer('running_mean', torch.zeros(in_planes)) +# self.register_buffer('running_var', torch.zeros(in_planes)) + +# self.register_buffer('threshold', torch.tensor(0.0)) +# self.register_buffer('flag_update_statistics', torch.tensor(0)) +# self.register_buffer('batch_num', torch.tensor(0.0)) + +# def forward(self, input): +# if self.flag_update_statistics: +# # Calculate running mean and variance over batch, height, and width dimensions +# self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) +# self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) + +# bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + +# # Use the custom autograd function for forward and backward passes +# output = MeanSparseFunction2D.apply(input, bias, crop, self.threshold) +# return output + + +### original +class MeanSparse(nn.Module): + def __init__(self, in_planes): + super(MeanSparse, self).__init__() + + self.register_buffer('running_mean', torch.zeros(in_planes)) + self.register_buffer('running_var', torch.zeros(in_planes)) + + self.register_buffer('threshold', torch.tensor(0.0)) + + self.register_buffer('flag_update_statistics', torch.tensor(0)) + self.register_buffer('batch_num', torch.tensor(0.0)) + + self.bias = None + self.crop = None + + def forward(self, input): + + if self.flag_update_statistics: + self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + + bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) + crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + + diff = input - bias + + if self.threshold == 0: + output = input + else: + output = torch.where(torch.abs(diff) < crop, bias*torch.ones_like(input), input) + + # if self.bias is None: + # self.bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) + # self.crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + + # diff = input - self.bias + + # if self.threshold == 0: + # output = input + # else: + # output = torch.where(torch.abs(diff) < self.crop, self.bias, input) + + return output + +class _Swish(torch.autograd.Function): + """Custom implementation of swish.""" + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(nn.Module): + """Module using custom implementation.""" + + def forward(self, input_tensor): + return _Swish.apply(input_tensor) + + +class _Block(nn.Module): + """WideResNet Block.""" + + def __init__(self, + in_planes, + out_planes, + stride, + activation_fn: Type[nn.Module] = nn.ReLU): + super().__init__() + self.batchnorm_0 = 
nn.BatchNorm2d(in_planes) + self.meansparse_0 = MeanSparse(in_planes) + self.relu_0 = activation_fn() + # We manually pad to obtain the same effect as `SAME` (necessary when + # `stride` is different than 1). + self.conv_0 = nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=0, + bias=False) + self.batchnorm_1 = nn.BatchNorm2d(out_planes) + self.meansparse_1 = MeanSparse(out_planes) + self.relu_1 = activation_fn() + self.conv_1 = nn.Conv2d(out_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.has_shortcut = in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False) + else: + self.shortcut = None + self._stride = stride + self.meansparse_2 = MeanSparse(out_planes) + + def forward(self, x): + if self.has_shortcut: + x = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + else: + out = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + v = x if self.has_shortcut else out + if self._stride == 1: + v = F.pad(v, (1, 1, 1, 1)) + elif self._stride == 2: + v = F.pad(v, (0, 1, 0, 1)) + else: + raise ValueError('Unsupported `stride`.') + out = self.conv_0(v) + out = self.relu_1(self.meansparse_1(self.batchnorm_1(out))) + out = self.conv_1(out) + out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) + out = self.meansparse_2(out) + return out + + +class _BlockGroup(nn.Module): + """WideResNet block group.""" + + def __init__(self, + num_blocks, + in_planes, + out_planes, + stride, + activation_fn: Type[nn.Module] = nn.ReLU): + super().__init__() + block = [] + for i in range(num_blocks): + block.append( + _Block(i == 0 and in_planes or out_planes, + out_planes, + i == 0 and stride or 1, + activation_fn=activation_fn)) + self.block = nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class DMWideResNet(nn.Module): + """WideResNet.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 28, + width: int = 10, + activation_fn: Type[nn.Module] = nn.ReLU, + mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, + std: Union[Tuple[float, ...], float] = CIFAR10_STD, + padding: int = 0, + num_input_channels: int = 3): + super().__init__() + # persistent=False to not put these tensors in the module's state_dict and not try to + # load it from the checkpoint + self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), + persistent=False) + self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), + persistent=False) + self.padding = padding + num_channels = [16, 16 * width, 32 * width, 64 * width] + assert (depth - 4) % 6 == 0 + num_blocks = (depth - 4) // 6 + self.init_conv = nn.Conv2d(num_input_channels, + num_channels[0], + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.layer = nn.Sequential( + _BlockGroup(num_blocks, + num_channels[0], + num_channels[1], + 1, + activation_fn=activation_fn), + _BlockGroup(num_blocks, + num_channels[1], + num_channels[2], + 2, + activation_fn=activation_fn), + _BlockGroup(num_blocks, + num_channels[2], + num_channels[3], + 2, + activation_fn=activation_fn)) + self.batchnorm = nn.BatchNorm2d(num_channels[3]) + self.meansparse_end = MeanSparse(num_channels[3]) + self.relu = activation_fn() + self.logits = nn.Linear(num_channels[3], num_classes) + self.num_channels = num_channels[3] + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + out = (x - self.mean) / self.std + out = 
self.init_conv(out) + out = self.layer(out) + out = self.relu(self.meansparse_end(self.batchnorm(out))) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.num_channels) + return self.logits(out) + + +class _PreActBlock(nn.Module): + """Pre-activation ResNet Block.""" + + def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): + super().__init__() + self._stride = stride + self.batchnorm_0 = nn.BatchNorm2d(in_planes) + self.relu_0 = activation_fn() + # We manually pad to obtain the same effect as `SAME` (necessary when + # `stride` is different than 1). + self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=0, bias=False) + self.batchnorm_1 = nn.BatchNorm2d(out_planes) + self.relu_1 = activation_fn() + self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.has_shortcut = stride != 1 or in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=0, bias=False) + + def _pad(self, x): + if self._stride == 1: + x = F.pad(x, (1, 1, 1, 1)) + elif self._stride == 2: + x = F.pad(x, (0, 1, 0, 1)) + else: + raise ValueError('Unsupported `stride`.') + return x + + def forward(self, x): + out = self.relu_0(self.batchnorm_0(x)) + shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x + out = self.conv_2d_1(self._pad(out)) + out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out))) + return out + shortcut + + +class DMPreActResNet(nn.Module): + """Pre-activation ResNet.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 18, + width: int = 0, # Used to make the constructor consistent. + activation_fn: Type[nn.Module] = nn.ReLU, + mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, + std: Union[Tuple[float, ...], float] = CIFAR10_STD, + padding: int = 0, + num_input_channels: int = 3, + use_cuda: bool = True): + super().__init__() + if width != 0: + raise ValueError('Unsupported `width`.') + # persistent=False to not put these tensors in the module's state_dict and not try to + # load it from the checkpoint + self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), + persistent=False) + self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), + persistent=False) + self.mean_cuda = None + self.std_cuda = None + self.padding = padding + self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1, + padding=1, bias=False) + if depth == 18: + num_blocks = (2, 2, 2, 2) + elif depth == 34: + num_blocks = (3, 4, 6, 3) + else: + raise ValueError('Unsupported `depth`.') + self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn) + self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn) + self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn) + self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn) + self.batchnorm = nn.BatchNorm2d(512) + self.relu = activation_fn() + self.logits = nn.Linear(512, num_classes) + + def _make_layer(self, in_planes, out_planes, num_blocks, stride, + activation_fn): + layers = [] + for i, stride in enumerate([stride] + [1] * (num_blocks - 1)): + layers.append( + _PreActBlock(i == 0 and in_planes or out_planes, + out_planes, + stride, + activation_fn)) + return nn.Sequential(*layers) + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + out = (x - self.mean) / self.std + out = self.conv_2d(out) + out = 
self.layer_0(out) + out = self.layer_1(out) + out = self.layer_2(out) + out = self.layer_3(out) + out = self.relu(self.batchnorm(out)) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + return self.logits(out) \ No newline at end of file diff --git a/robustbench/model_zoo/architectures/Meansparse_wrn_94_16.py b/robustbench/model_zoo/architectures/Meansparse_wrn_94_16.py new file mode 100644 index 0000000..6244e6b --- /dev/null +++ b/robustbench/model_zoo/architectures/Meansparse_wrn_94_16.py @@ -0,0 +1,388 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""WideResNet implementation in PyTorch. From: +https://github.com/deepmind/deepmind-research/blob/master/adversarial_robustness/pytorch/model_zoo.py +""" + +from typing import Tuple, Type, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +CIFAR10_MEAN = (0.4914, 0.4822, 0.4465) +CIFAR10_STD = (0.2471, 0.2435, 0.2616) +CIFAR100_MEAN = (0.5071, 0.4865, 0.4409) +CIFAR100_STD = (0.2673, 0.2564, 0.2762) + + +# Custom autograd function for forward and backward modification +# class MeanSparseFunction2D(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, bias, crop, threshold): +# # Save context variables for backward computation if needed +# ctx.save_for_backward(input, bias, crop, threshold) + +# # Forward computation (given in the question) +# if threshold == 0: +# output = input +# else: +# diff = input - bias +# output = torch.where(torch.abs(diff) < crop, bias, input) + +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# # For backward, we want output = input, so we pass grad_output as-is. 
+# input, bias, crop, threshold = ctx.saved_tensors + +# # Here we assume output = input in backward, so gradient is unchanged +# grad_input = grad_output +# return grad_input, None, None, None # Other inputs (bias, crop, threshold) have no gradients + +# # Define the MeanSparse module with modified backward behavior +# class MeanSparse(nn.Module): +# def __init__(self, in_planes): +# super(MeanSparse, self).__init__() + +# self.register_buffer('running_mean', torch.zeros(in_planes)) +# self.register_buffer('running_var', torch.zeros(in_planes)) + +# self.register_buffer('threshold', torch.tensor(0.0)) +# self.register_buffer('flag_update_statistics', torch.tensor(0)) +# self.register_buffer('batch_num', torch.tensor(0.0)) + +# def forward(self, input): +# if self.flag_update_statistics: +# # Calculate running mean and variance over batch, height, and width dimensions +# self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) +# self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3)) / self.batch_num) + +# bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + +# # Use the custom autograd function for forward and backward passes +# output = MeanSparseFunction2D.apply(input, bias, crop, self.threshold) +# return output + + +### original +class MeanSparse(nn.Module): + def __init__(self, in_planes): + super(MeanSparse, self).__init__() + + self.register_buffer('running_mean', torch.zeros(in_planes)) + self.register_buffer('running_var', torch.zeros(in_planes)) + + self.register_buffer('threshold', torch.tensor(0.0)) + + self.register_buffer('flag_update_statistics', torch.tensor(0)) + self.register_buffer('batch_num', torch.tensor(0.0)) + + def forward(self, input): + + if self.flag_update_statistics: + self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) + + bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) + crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + + diff = input - bias + + if self.threshold == 0: + output = input + else: + output = torch.where(torch.abs(diff) < crop, bias*torch.ones_like(input), input) + + return output + +class _Swish(torch.autograd.Function): + """Custom implementation of swish.""" + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class Swish(nn.Module): + """Module using custom implementation.""" + + def forward(self, input_tensor): + return _Swish.apply(input_tensor) + + +class _Block(nn.Module): + """WideResNet Block.""" + + def __init__(self, + in_planes, + out_planes, + stride, + activation_fn: Type[nn.Module] = nn.ReLU): + super().__init__() + self.batchnorm_0 = nn.BatchNorm2d(in_planes) + self.meansparse_0 = MeanSparse(in_planes) + + self.relu_0 = activation_fn() + # We manually pad to obtain the same effect as `SAME` (necessary when + # `stride` is different than 1). 
+ self.conv_0 = nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=0, + bias=False) + self.batchnorm_1 = nn.BatchNorm2d(out_planes) + self.meansparse_1 = MeanSparse(out_planes) + + self.relu_1 = activation_fn() + self.conv_1 = nn.Conv2d(out_planes, + out_planes, + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.has_shortcut = in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + padding=0, + bias=False) + else: + self.shortcut = None + self._stride = stride + self.meansparse_2 = MeanSparse(out_planes) + + def forward(self, x): + if self.has_shortcut: + x = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + else: + out = self.relu_0(self.meansparse_0(self.batchnorm_0(x))) + v = x if self.has_shortcut else out + if self._stride == 1: + v = F.pad(v, (1, 1, 1, 1)) + elif self._stride == 2: + v = F.pad(v, (0, 1, 0, 1)) + else: + raise ValueError('Unsupported `stride`.') + out = self.conv_0(v) + out = self.relu_1(self.meansparse_1(self.batchnorm_1(out))) + out = self.conv_1(out) + out = torch.add(self.shortcut(x) if self.has_shortcut else x, out) + out = self.meansparse_2(out) + return out + + +class _BlockGroup(nn.Module): + """WideResNet block group.""" + + def __init__(self, + num_blocks, + in_planes, + out_planes, + stride, + activation_fn: Type[nn.Module] = nn.ReLU): + super().__init__() + block = [] + for i in range(num_blocks): + block.append( + _Block(i == 0 and in_planes or out_planes, + out_planes, + i == 0 and stride or 1, + activation_fn=activation_fn)) + self.block = nn.Sequential(*block) + + def forward(self, x): + return self.block(x) + + +class MeanSparse_DMWideResNet(nn.Module): + """WideResNet.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 28, + width: int = 10, + activation_fn: Type[nn.Module] = nn.ReLU, + mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, + std: Union[Tuple[float, ...], float] = CIFAR10_STD, + padding: int = 0, + num_input_channels: int = 3): + super().__init__() + # persistent=False to not put these tensors in the module's state_dict and not try to + # load it from the checkpoint + self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), + persistent=False) + self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), + persistent=False) + self.padding = padding + num_channels = [16, 16 * width, 32 * width, 64 * width] + assert (depth - 4) % 6 == 0 + num_blocks = (depth - 4) // 6 + self.init_conv = nn.Conv2d(num_input_channels, + num_channels[0], + kernel_size=3, + stride=1, + padding=1, + bias=False) + self.layer = nn.Sequential( + _BlockGroup(num_blocks, + num_channels[0], + num_channels[1], + 1, + activation_fn=activation_fn), + _BlockGroup(num_blocks, + num_channels[1], + num_channels[2], + 2, + activation_fn=activation_fn), + _BlockGroup(num_blocks, + num_channels[2], + num_channels[3], + 2, + activation_fn=activation_fn)) + self.batchnorm = nn.BatchNorm2d(num_channels[3]) + self.relu = activation_fn() + self.logits = nn.Linear(num_channels[3], num_classes) + self.num_channels = num_channels[3] + self.meansparse_end = MeanSparse(num_channels[3]) + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + out = (x - self.mean) / self.std + out = self.init_conv(out) + out = self.layer(out) + out = self.relu(self.meansparse_end(self.batchnorm(out))) + out = F.avg_pool2d(out, 8) + out = out.view(-1, self.num_channels) + return self.logits(out) 
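+
+# Usage sketch (illustration only, not part of the released model): MeanSparse
+# behaves as the identity until its buffers are populated. In practice
+# `running_mean`/`running_var` come from the downloaded checkpoint, and
+# `threshold` controls how aggressively near-mean activations are replaced by
+# the per-channel running mean. A hypothetical standalone check:
+#
+#   layer = MeanSparse(in_planes=16)
+#   layer.running_var.fill_(1.0)   # assume unit variance, for illustration only
+#   layer.threshold.fill_(0.25)
+#   x = torch.randn(2, 16, 8, 8)
+#   y = layer(x)                   # entries with |x - mean| < 0.25 * std become the mean
+#   assert y.shape == x.shape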
+ + +class _PreActBlock(nn.Module): + """Pre-activation ResNet Block.""" + + def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): + super().__init__() + self._stride = stride + self.batchnorm_0 = nn.BatchNorm2d(in_planes) + self.relu_0 = activation_fn() + # We manually pad to obtain the same effect as `SAME` (necessary when + # `stride` is different than 1). + self.conv_2d_1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=0, bias=False) + self.batchnorm_1 = nn.BatchNorm2d(out_planes) + self.relu_1 = activation_fn() + self.conv_2d_2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, + padding=1, bias=False) + self.has_shortcut = stride != 1 or in_planes != out_planes + if self.has_shortcut: + self.shortcut = nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=stride, padding=0, bias=False) + + def _pad(self, x): + if self._stride == 1: + x = F.pad(x, (1, 1, 1, 1)) + elif self._stride == 2: + x = F.pad(x, (0, 1, 0, 1)) + else: + raise ValueError('Unsupported `stride`.') + return x + + def forward(self, x): + out = self.relu_0(self.batchnorm_0(x)) + shortcut = self.shortcut(self._pad(x)) if self.has_shortcut else x + out = self.conv_2d_1(self._pad(out)) + out = self.conv_2d_2(self.relu_1(self.batchnorm_1(out))) + return out + shortcut + + +class DMPreActResNet(nn.Module): + """Pre-activation ResNet.""" + + def __init__(self, + num_classes: int = 10, + depth: int = 18, + width: int = 0, # Used to make the constructor consistent. + activation_fn: Type[nn.Module] = nn.ReLU, + mean: Union[Tuple[float, ...], float] = CIFAR10_MEAN, + std: Union[Tuple[float, ...], float] = CIFAR10_STD, + padding: int = 0, + num_input_channels: int = 3, + use_cuda: bool = True): + super().__init__() + if width != 0: + raise ValueError('Unsupported `width`.') + # persistent=False to not put these tensors in the module's state_dict and not try to + # load it from the checkpoint + self.register_buffer('mean', torch.tensor(mean).view(num_input_channels, 1, 1), + persistent=False) + self.register_buffer('std', torch.tensor(std).view(num_input_channels, 1, 1), + persistent=False) + self.mean_cuda = None + self.std_cuda = None + self.padding = padding + self.conv_2d = nn.Conv2d(num_input_channels, 64, kernel_size=3, stride=1, + padding=1, bias=False) + if depth == 18: + num_blocks = (2, 2, 2, 2) + elif depth == 34: + num_blocks = (3, 4, 6, 3) + else: + raise ValueError('Unsupported `depth`.') + self.layer_0 = self._make_layer(64, 64, num_blocks[0], 1, activation_fn) + self.layer_1 = self._make_layer(64, 128, num_blocks[1], 2, activation_fn) + self.layer_2 = self._make_layer(128, 256, num_blocks[2], 2, activation_fn) + self.layer_3 = self._make_layer(256, 512, num_blocks[3], 2, activation_fn) + self.batchnorm = nn.BatchNorm2d(512) + self.relu = activation_fn() + self.logits = nn.Linear(512, num_classes) + + def _make_layer(self, in_planes, out_planes, num_blocks, stride, + activation_fn): + layers = [] + for i, stride in enumerate([stride] + [1] * (num_blocks - 1)): + layers.append( + _PreActBlock(i == 0 and in_planes or out_planes, + out_planes, + stride, + activation_fn)) + return nn.Sequential(*layers) + + def forward(self, x): + if self.padding > 0: + x = F.pad(x, (self.padding,) * 4) + out = (x - self.mean) / self.std + out = self.conv_2d(out) + out = self.layer_0(out) + out = self.layer_1(out) + out = self.layer_2(out) + out = self.layer_3(out) + out = self.relu(self.batchnorm(out)) + out = F.avg_pool2d(out, 4) + out = out.view(out.size(0), -1) + return 
self.logits(out) \ No newline at end of file diff --git a/robustbench/model_zoo/architectures/sparsified_model.py b/robustbench/model_zoo/architectures/sparsified_model.py index fcce868..bfc70ad 100644 --- a/robustbench/model_zoo/architectures/sparsified_model.py +++ b/robustbench/model_zoo/architectures/sparsified_model.py @@ -2,48 +2,11 @@ import torch.nn as nn from timm.layers.activations import GELU -class MeanSparse_cifar_10(nn.Module): - def __init__(self, in_planes, momentum=0.1): - super(MeanSparse_cifar_10, self).__init__() - - self.register_buffer('running_mean', torch.zeros(in_planes)) - self.register_buffer('running_var', torch.zeros(in_planes)) - - self.register_buffer('threshold', torch.tensor(0.0)) - - self.register_buffer('flag_update_statistics', torch.tensor(0)) - self.register_buffer('batch_num', torch.tensor(0.0)) - - def forward(self, input): - - if self.flag_update_statistics: - self.running_mean += (torch.mean(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) - self.running_var += (torch.var(input.detach().clone(), dim=(0, 2, 3))/self.batch_num) - - bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) - crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) - - diff = input - bias - - if self.threshold == 0: - output = input - else: - output = torch.where(torch.abs(diff) < crop, bias*torch.ones_like(input), input) - - return output - -def add_custom_layer_cifar_10(model, custom_layer_class, parent_path='', prev_features=None): - for name, child in model.named_children(): - current_path = f"{parent_path}.{name}" if parent_path else name # Build the current path - - if isinstance(child, nn.SiLU): - modified_layer = nn.Sequential(custom_layer_class(prev_features), child) - setattr(model, name, modified_layer) - - elif isinstance(child, nn.BatchNorm2d): - prev_features = child.num_features +from robustbench.model_zoo.architectures.Meansparse_wrn_70_16 import DMWideResNet as s_wrn_70_16 +from robustbench.model_zoo.architectures.Meansparse_swin_L import swin_large_patch4_window7_224_with_MeanSparse +from robustbench.model_zoo.architectures.Meansparse_wrn_94_16 import MeanSparse_DMWideResNet +from robustbench.model_zoo.architectures.Meansparse_ra_wrn_70_16 import NormalizedWideResNet - add_custom_layer_cifar_10(child, custom_layer_class, current_path, prev_features) def add_custom_layer_imagenet(model, custom_layer_class, parent_path='', prev_features=None): for name, child in model.named_children(): @@ -77,6 +40,75 @@ def add_custom_layer_imagenet(model, custom_layer_class, parent_path='', prev_fe add_custom_layer_imagenet(child, custom_layer_class, current_path, prev_features) + +# Custom autograd function for controlling forward and backward behavior +# class MeanSparseFunctionImagenet(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, bias, crop, threshold): +# # Save for backward pass if needed +# ctx.save_for_backward(input, bias, crop, threshold) + +# # Forward pass logic as provided +# if threshold == 0: +# output = input +# else: +# diff = input - bias +# output = torch.where(torch.abs(diff) < crop, bias, input) + +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# # Ignore the torch.where() operation in backward; simply pass through grad_output +# input, bias, crop, threshold = ctx.saved_tensors +# grad_input = grad_output # As if output = input in backward +# return grad_input, None, None, None # Only input has gradients; bias, crop, threshold do not + + +# # 
MeanSparse_imagenet module with custom backward behavior +# class MeanSparse_imagenet(nn.Module): +# def __init__(self, in_planes, momentum=0.1): +# super(MeanSparse_imagenet, self).__init__() + +# self.register_buffer('momentum', torch.tensor(momentum)) +# self.register_buffer('epsilon', torch.tensor(1.0e-10)) + +# self.register_buffer('running_mean', torch.zeros(in_planes)) +# self.register_buffer('running_var', torch.zeros(in_planes)) +# self.register_buffer('threshold', torch.tensor(0.0)) + +# self.register_buffer('flag_update_statistics', torch.tensor(0)) +# self.register_buffer('batch_num', torch.tensor(0.0)) + +# def forward(self, input): +# if input.shape[1] == self.running_mean.shape[0]: +# # Update running statistics if flag is set +# if self.flag_update_statistics: +# self.running_mean += (torch.mean(input.detach(), dim=(0, 2, 3)) / self.batch_num) +# self.running_var += (torch.var(input.detach(), dim=(0, 2, 3)) / self.batch_num) + +# bias = self.running_mean.view(1, self.running_mean.shape[0], 1, 1) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, self.running_var.shape[0], 1, 1) + +# # Use the custom autograd function +# output = MeanSparseFunctionImagenet.apply(input, bias, crop, self.threshold) + +# else: +# # Update running statistics if flag is set +# if self.flag_update_statistics: +# self.running_mean += (torch.mean(input.detach(), dim=(0, 1, 2)) / self.batch_num) +# self.running_var += (torch.var(input.detach(), dim=(0, 1, 2)) / self.batch_num) + +# bias = self.running_mean.view(1, 1, 1, self.running_mean.shape[0]) +# crop = self.threshold * torch.sqrt(self.running_var).view(1, 1, 1, self.running_var.shape[0]) + +# # Use the custom autograd function +# output = MeanSparseFunctionImagenet.apply(input, bias, crop, self.threshold) + +# return output + + +### original class MeanSparse_imagenet(nn.Module): def __init__(self, in_planes, momentum=0.1): super(MeanSparse_imagenet, self).__init__() @@ -131,10 +163,32 @@ def forward(self, input): return output -def get_sparse_model(model, dataset='cifar-10'): - if dataset == 'cifar-10': - add_custom_layer_cifar_10(model, MeanSparse_cifar_10, parent_path='', prev_features=None) - elif dataset == 'imagenet': - add_custom_layer_imagenet(model, MeanSparse_imagenet, parent_path='', prev_features=None) - + +def get_sparse_model(model, dataset): + if dataset == 'cifar-10-Linf': + if model == 'wrn_94_16': + model = MeanSparse_DMWideResNet(num_classes=10, depth=94, width=16, activation_fn=nn.SiLU, + mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)) + elif model == 'ra_wrn_70_16': + model = NormalizedWideResNet(mean = (0.4914, 0.4822, 0.4465), std = (0.2471, 0.2435, 0.2616), + stem_width = 96, depth = [30, 31, 10], stage_width = [216, 432, 864], + groups = [1, 1, 1], activation_fn = torch.nn.modules.activation.SiLU, se_ratio = 0.25, + se_activation = torch.nn.modules.activation.ReLU, se_order = 2, num_classes = 10, + padding = 0, num_input_channels = 3) + + elif dataset == 'imagenet-Linf': + if model == 'swin-l': + model = swin_large_patch4_window7_224_with_MeanSparse( + pretrained=False, pretrained_cfg=None, pretrained_cfg_overlay=None) + else: + add_custom_layer_imagenet(model, MeanSparse_imagenet, parent_path='', prev_features=None) + elif dataset == 'cifar-100-Linf': + model = s_wrn_70_16(num_classes=100, depth=70, width=16, activation_fn=nn.SiLU, + mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)) + elif dataset == 'cifar-10-L2': + model = s_wrn_70_16(num_classes=10, depth=70, width=16, 
activation_fn=nn.SiLU, + mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)) + else: + raise ValueError(f'Unknown dataset: {dataset}.') + return model \ No newline at end of file diff --git a/robustbench/model_zoo/cifar10.py b/robustbench/model_zoo/cifar10.py index 885d228..ac6a5a0 100644 --- a/robustbench/model_zoo/cifar10.py +++ b/robustbench/model_zoo/cifar10.py @@ -903,10 +903,20 @@ def forward(self, x): 'model': lambda: WideResNet(depth=34, widen_factor=20), 'gdrive_id': '1-IbKAGtp79tAEm59N0i8QMvMZ3nxSD2-', }), - ('Amini2024MeanSparse', { + # ('Amini2024MeanSparse', { + # 'model': lambda: get_sparse_model( + # get_robustarch_model('ra_wrn70_16'), dataset='cifar-10'), # TODO: check device calls. + # 'gdrive_id': '1-4XSB3Ir-pn5gnEJ4TbUdkwBclIay8-q', + # }), + ('Amini2024MeanSparse_Ra_WRN_70_16', { 'model': lambda: get_sparse_model( - get_robustarch_model('ra_wrn70_16'), dataset='cifar-10'), # TODO: check device calls. - 'gdrive_id': '1-4XSB3Ir-pn5gnEJ4TbUdkwBclIay8-q', + 'ra_wrn_70_16', dataset='cifar-10-Linf'), # TODO: check device calls. + 'gdrive_id': '1-JdK480cLUrVCaNKCSBsmy7y59kGYF19', + }), + ('Amini2024MeanSparse_S-WRN-94-16', { + 'model': lambda: get_sparse_model( + 'wrn_94_16', dataset='cifar-10-Linf'), # TODO: check device calls. + 'gdrive_id': '1-GoAnBP6K6uwAzJbe4liOSfN_sZVYTln', }), ('Bartoldson2024Adversarial_WRN-94-16', { 'model': @@ -1076,6 +1086,11 @@ def forward(self, x): 'gdrive_id': '101UyURyte05tJLA9VFRBN6UDLyn-0sJw' }), + ('Amini2024MeanSparse_S-WRN-70-16', { + 'model': lambda: get_sparse_model( + 'wrn_70_16', dataset='cifar-10-L2'), # TODO: check device calls. + 'gdrive_id': '1kBtoDnMwbVmsFl4YxlDhuJpRiE-gCKvI', #'1-IxGA3BI1nK5iMkgXbSoaDZF5c7vhe8b', + }), ]) common_corruptions = OrderedDict([ diff --git a/robustbench/model_zoo/cifar100.py b/robustbench/model_zoo/cifar100.py index 9e114b3..9413826 100644 --- a/robustbench/model_zoo/cifar100.py +++ b/robustbench/model_zoo/cifar100.py @@ -15,6 +15,7 @@ from robustbench.model_zoo.architectures import xcit from robustbench.model_zoo.architectures.comp_model import get_composite_model, \ get_nonlin_mixed_classifier +from robustbench.model_zoo.architectures.sparsified_model import get_sparse_model class Chen2020EfficientNet(WideResNet): @@ -521,6 +522,10 @@ def forward(self, x): 'model': lambda: WideResNet(depth=34, widen_factor=10, num_classes=100), 'gdrive_id': '1-N5a0Z9o-8-oIrF4_EaVFVOOPI5-D-H2', }), + ('Amini2024MeanSparse_S-WRN-70-16', { + 'model': lambda: get_sparse_model('wrn_70_16', dataset='cifar-100-Linf'), # TODO: check device calls. 
+ 'gdrive_id': '1-JaGaVHX3nBiXhDmZSB6WgD2xzbeO2wD', + }), ]) common_corruptions = OrderedDict([ diff --git a/robustbench/model_zoo/imagenet.py b/robustbench/model_zoo/imagenet.py index 3664368..349c02b 100644 --- a/robustbench/model_zoo/imagenet.py +++ b/robustbench/model_zoo/imagenet.py @@ -151,13 +151,30 @@ 'gdrive_id': '1-dUFdvDBflqMsMLjZv3wlPJTm-Jm7net', 'preprocessing': 'Res224', }), - ('Amini2024MeanSparse', { + ('Amini2024MeanSparse_ConvNeXt-L', { 'model': lambda: get_sparse_model( normalize_model(timm.create_model('convnext_large', pretrained=False), - mu, sigma), dataset='imagenet'), + mu, sigma), dataset='imagenet-Linf'), 'gdrive_id': '1-LUMPqauSx68bPmZFIuklFoJ6NmBhu7A', 'preprocessing': 'BicubicRes256Crop224', }), + ('Amini2024MeanSparse_Swin-L', { + 'model': lambda: get_sparse_model('swin-l', dataset='imagenet-Linf'), + 'gdrive_id': '1-KmvrDXd_kcJS-TcNmtHP5NInQ5I4lgS', + 'preprocessing': 'BicubicRes256Crop224', + }), + ('RodriguezMunoz2024Characterizing_Swin-B', { + 'model': lambda: normalize_model(timm.create_model( + 'swin_base_patch4_window7_224', pretrained=False), mu, sigma), + 'gdrive_id': '1-BSUjoFXx3PP-TfeE5fjbofO2lLyUf56', # '1-9h_4PImbQM3XhKBcnqTh4PHxz9rM6vr', + 'preprocessing': 'BicubicRes256Crop224' + }), + ('RodriguezMunoz2024Characterizing_Swin-L', { + 'model': lambda: normalize_model(timm.create_model( + 'swin_large_patch4_window7_224', pretrained=False), mu, sigma), + 'gdrive_id': '1-Dc9WhPU2wv4OMskLo1U57n5O8VbpNXv', # '1-DoJoTiPynr39AFNsEyhOej4rKPL3xqT' + 'preprocessing': 'BicubicRes256Crop224' + }), ]) common_corruptions = OrderedDict( diff --git a/robustbench/utils.py b/robustbench/utils.py index 14ff757..491557a 100644 --- a/robustbench/utils.py +++ b/robustbench/utils.py @@ -183,6 +183,8 @@ def load_model(model_name: str, 'Liu2023Comprehensive_Swin-B', 'Liu2023Comprehensive_Swin-L', 'Mo2022When_Swin-B', + 'RodriguezMunoz2024Characterizing_Swin-B', + 'RodriguezMunoz2024Characterizing_Swin-L', ]: try: from timm.models.swin_transformer import checkpoint_filter_fn From c288dc46b8eaa2e726ebde8e56a674c111fcc1b8 Mon Sep 17 00:00:00 2001 From: fra31 Date: Fri, 20 Dec 2024 14:19:08 +0000 Subject: [PATCH 2/3] add model info, RodriguezMunoz2024Characterizing models --- .../L2/Amini2024MeanSparse_S-WRN-70-16.json | 17 +++++++++++++++++ .../Linf/Amini2024MeanSparse_Ra_WRN_70_16.json | 17 +++++++++++++++++ .../Linf/Amini2024MeanSparse_S-WRN-94-16.json | 17 +++++++++++++++++ .../Linf/Amini2024MeanSparse_S-WRN-70-16.json | 17 +++++++++++++++++ ...json => Amini2024MeanSparse_ConvNeXt-L.json} | 5 +++-- .../Linf/Amini2024MeanSparse_Swin-L.json | 17 +++++++++++++++++ ...RodriguezMunoz2024Characterizing_Swin-B.json | 15 +++++++++++++++ ...RodriguezMunoz2024Characterizing_Swin-L.json | 15 +++++++++++++++ 8 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 model_info/cifar10/L2/Amini2024MeanSparse_S-WRN-70-16.json create mode 100644 model_info/cifar10/Linf/Amini2024MeanSparse_Ra_WRN_70_16.json create mode 100644 model_info/cifar10/Linf/Amini2024MeanSparse_S-WRN-94-16.json create mode 100644 model_info/cifar100/Linf/Amini2024MeanSparse_S-WRN-70-16.json rename model_info/imagenet/Linf/{Amini2024MeanSparse.json => Amini2024MeanSparse_ConvNeXt-L.json} (76%) create mode 100644 model_info/imagenet/Linf/Amini2024MeanSparse_Swin-L.json create mode 100644 model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-B.json create mode 100644 model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-L.json diff --git 
a/model_info/cifar10/L2/Amini2024MeanSparse_S-WRN-70-16.json b/model_info/cifar10/L2/Amini2024MeanSparse_S-WRN-70-16.json new file mode 100644 index 0000000..44f8670 --- /dev/null +++ b/model_info/cifar10/L2/Amini2024MeanSparse_S-WRN-70-16.json @@ -0,0 +1,17 @@ +{ + "link": "https://arxiv.org/abs/2406.05927", + "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification", + "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr", + "additional_data": true, + "number_forward_passes": 1, + "dataset": "cifar10", + "venue": "arXiv, Jun 2024", + "architecture": "MeanSparse WideResNet-70-16", + "eps": "0.5", + "clean_acc": "95.51", + "reported": "87.28", + "autoattack_acc": "87.28", + "external": "84.33", + "footnote": "It adds the MeanSparse operator to the adversarially trained model Wang2023Better_WRN-70-16. 84.33% robust accuracy is due to APGD (both versions) with BPDA.", + "unreliable": false +} \ No newline at end of file diff --git a/model_info/cifar10/Linf/Amini2024MeanSparse_Ra_WRN_70_16.json b/model_info/cifar10/Linf/Amini2024MeanSparse_Ra_WRN_70_16.json new file mode 100644 index 0000000..f0bee47 --- /dev/null +++ b/model_info/cifar10/Linf/Amini2024MeanSparse_Ra_WRN_70_16.json @@ -0,0 +1,17 @@ +{ + "link": "https://arxiv.org/abs/2406.05927", + "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification", + "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr", + "additional_data": true, + "number_forward_passes": 1, + "dataset": "cifar10", + "venue": "arXiv, Jun 2024", + "architecture": "MeanSparse RaWideResNet-70-16", + "eps": "8/255", + "clean_acc": "93.24", + "reported": "72.08", + "autoattack_acc": "72.08", + "external": "68.94", + "footnote": "It adds the MeanSparse operator to the adversarially trained model Peng2023Robust. 68.94% robust accuracy is due to APGD (both versions) with BPDA.", + "unreliable": false +} \ No newline at end of file diff --git a/model_info/cifar10/Linf/Amini2024MeanSparse_S-WRN-94-16.json b/model_info/cifar10/Linf/Amini2024MeanSparse_S-WRN-94-16.json new file mode 100644 index 0000000..2f47662 --- /dev/null +++ b/model_info/cifar10/Linf/Amini2024MeanSparse_S-WRN-94-16.json @@ -0,0 +1,17 @@ +{ + "link": "https://arxiv.org/abs/2406.05927", + "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification", + "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr", + "additional_data": true, + "number_forward_passes": 1, + "dataset": "cifar10", + "venue": "arXiv, Jun 2024", + "architecture": "MeanSparse WideResNet-94-16", + "eps": "8/255", + "clean_acc": "93.60", + "reported": "75.28", + "autoattack_acc": "75.28", + "external": "73.10", + "footnote": "It adds the MeanSparse operator to the adversarially trained model Bartoldson2024Adversarial_WRN-94-16. 
73.10% robust accuracy is due to APGD (both versions) with BPDA.",
+    "unreliable": false
+}
\ No newline at end of file
diff --git a/model_info/cifar100/Linf/Amini2024MeanSparse_S-WRN-70-16.json b/model_info/cifar100/Linf/Amini2024MeanSparse_S-WRN-70-16.json
new file mode 100644
index 0000000..144e4ea
--- /dev/null
+++ b/model_info/cifar100/Linf/Amini2024MeanSparse_S-WRN-70-16.json
@@ -0,0 +1,17 @@
+{
+    "link": "https://arxiv.org/abs/2406.05927",
+    "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification",
+    "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr",
+    "additional_data": true,
+    "number_forward_passes": 1,
+    "dataset": "cifar100",
+    "venue": "arXiv, Jun 2024",
+    "architecture": "MeanSparse WideResNet-70-16",
+    "eps": "8/255",
+    "clean_acc": "75.13",
+    "reported": "44.78",
+    "autoattack_acc": "44.78",
+    "external": "42.25",
+    "footnote": "It adds the MeanSparse operator to the adversarially trained model Wang2023Better_WRN-70-16. 42.25% robust accuracy is due to APGD (both versions) with BPDA.",
+    "unreliable": false
+}
\ No newline at end of file
diff --git a/model_info/imagenet/Linf/Amini2024MeanSparse.json b/model_info/imagenet/Linf/Amini2024MeanSparse_ConvNeXt-L.json
similarity index 76%
rename from model_info/imagenet/Linf/Amini2024MeanSparse.json
rename to model_info/imagenet/Linf/Amini2024MeanSparse_ConvNeXt-L.json
index 5ae866f..b04d32d 100644
--- a/model_info/imagenet/Linf/Amini2024MeanSparse.json
+++ b/model_info/imagenet/Linf/Amini2024MeanSparse_ConvNeXt-L.json
@@ -8,9 +8,10 @@
     "venue": "arXiv, Jun 2024",
     "architecture": "MeanSparse ConvNeXt-L",
     "eps": "4/255",
-    "clean_acc": "77.96",
+    "clean_acc": "77.92",
     "reported": "59.64",
     "autoattack_acc": "59.64",
-    "footnote": "It adds the MeanSparse operator to the adversarially trained models.",
+    "external": "58.22",
+    "footnote": "It adds the MeanSparse operator to the adversarially trained model Liu2023Comprehensive_ConvNeXt-L. 58.22% robust accuracy is due to APGD (both versions) with BPDA.",
     "unreliable": false
 }
\ No newline at end of file
diff --git a/model_info/imagenet/Linf/Amini2024MeanSparse_Swin-L.json b/model_info/imagenet/Linf/Amini2024MeanSparse_Swin-L.json
new file mode 100644
index 0000000..d043be5
--- /dev/null
+++ b/model_info/imagenet/Linf/Amini2024MeanSparse_Swin-L.json
@@ -0,0 +1,17 @@
+{
+    "link": "https://arxiv.org/abs/2406.05927",
+    "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification",
+    "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr",
+    "additional_data": false,
+    "number_forward_passes": 1,
+    "dataset": "imagenet",
+    "venue": "arXiv, Jun 2024",
+    "architecture": "MeanSparse Swin-L",
+    "eps": "4/255",
+    "clean_acc": "78.80",
+    "reported": "62.12",
+    "autoattack_acc": "62.12",
+    "external": "58.92",
+    "footnote": "It adds the MeanSparse operator to the adversarially trained model Liu2023Comprehensive_Swin-L. 
58.92% robust accuracy is due to APGD (both versions) with BPDA.", + "unreliable": false +} \ No newline at end of file diff --git a/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-B.json b/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-B.json new file mode 100644 index 0000000..4fdb545 --- /dev/null +++ b/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-B.json @@ -0,0 +1,15 @@ +{ + "link": "https://arxiv.org/abs/2409.20139", + "name": "Characterizing Model Robustness via Natural Input Gradients", + "authors": "Adrián Rodríguez-Muñoz, Tongzhou Wang, Antonio Torralba", + "additional_data": false, + "number_forward_passes": 1, + "dataset": "imagenet", + "venue": "arXiv, Sep 2024", + "architecture": "Swin-B", + "eps": "4/255", + "clean_acc": "77.76", + "reported": "51.56", + "autoattack_acc": "51.56", + "unreliable": false +} diff --git a/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-L.json b/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-L.json new file mode 100644 index 0000000..9529714 --- /dev/null +++ b/model_info/imagenet/Linf/RodriguezMunoz2024Characterizing_Swin-L.json @@ -0,0 +1,15 @@ +{ + "link": "https://arxiv.org/abs/2409.20139", + "name": "Characterizing Model Robustness via Natural Input Gradients", + "authors": "Adrián Rodríguez-Muñoz, Tongzhou Wang, Antonio Torralba", + "additional_data": false, + "number_forward_passes": 1, + "dataset": "imagenet", + "venue": "arXiv, Sep 2024", + "architecture": "Swin-L", + "eps": "4/255", + "clean_acc": "79.36", + "reported": "53.82", + "autoattack_acc": "53.82", + "unreliable": false +} From 6b78743917844108c1bfbe97696f3db6c629204d Mon Sep 17 00:00:00 2001 From: fra31 Date: Fri, 20 Dec 2024 15:02:56 +0000 Subject: [PATCH 3/3] add evaluations from MALT (Melamed et al., 2024) --- model_info/cifar10/Linf/Amini2024MeanSparse.json | 16 ---------------- .../cifar100/Linf/Wang2023Better_WRN-28-10.json | 3 ++- .../cifar100/Linf/Wang2023Better_WRN-70-16.json | 3 ++- 3 files changed, 4 insertions(+), 18 deletions(-) delete mode 100644 model_info/cifar10/Linf/Amini2024MeanSparse.json diff --git a/model_info/cifar10/Linf/Amini2024MeanSparse.json b/model_info/cifar10/Linf/Amini2024MeanSparse.json deleted file mode 100644 index 8c67b1a..0000000 --- a/model_info/cifar10/Linf/Amini2024MeanSparse.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "link": "https://arxiv.org/abs/2406.05927", - "name": "MeanSparse: Post-Training Robustness Enhancement Through Mean-Centered Feature Sparsification", - "authors": "Sajjad Amini, Mohammadreza Teymoorianfard, Shiqing Ma, Amir Houmansadr", - "additional_data": true, - "number_forward_passes": 1, - "dataset": "cifar10", - "venue": "arXiv, Jun 2024", - "architecture": "MeanSparse RaWideResNet-70-16", - "eps": "8/255", - "clean_acc": "93.24", - "reported": "72.08", - "autoattack_acc": "72.08", - "footnote": "It adds the MeanSparse operator to the adversarially trained models.", - "unreliable": false -} \ No newline at end of file diff --git a/model_info/cifar100/Linf/Wang2023Better_WRN-28-10.json b/model_info/cifar100/Linf/Wang2023Better_WRN-28-10.json index 98e138b..a38a511 100644 --- a/model_info/cifar100/Linf/Wang2023Better_WRN-28-10.json +++ b/model_info/cifar100/Linf/Wang2023Better_WRN-28-10.json @@ -11,6 +11,7 @@ "clean_acc": "72.58", "reported": "38.83", "autoattack_acc": "38.83", - "footnote": "It uses additional 50M synthetic images in training.", + "external": "38.77", + "footnote": "It uses additional 50M synthetic images in training. 
38.77% robust accuracy is given by MALT (Melamed et al., 2024).", "unreliable": false } \ No newline at end of file diff --git a/model_info/cifar100/Linf/Wang2023Better_WRN-70-16.json b/model_info/cifar100/Linf/Wang2023Better_WRN-70-16.json index 9092674..ab35feb 100644 --- a/model_info/cifar100/Linf/Wang2023Better_WRN-70-16.json +++ b/model_info/cifar100/Linf/Wang2023Better_WRN-70-16.json @@ -11,6 +11,7 @@ "clean_acc": "75.22", "reported": "42.67", "autoattack_acc": "42.67", - "footnote": "It uses additional 50M synthetic images in training.", + "external": "42.66", + "footnote": "It uses additional 50M synthetic images in training. 42.66% robust accuracy is given by MALT (Melamed et al., 2024).", "unreliable": false } \ No newline at end of file
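
The entries registered above are loaded through the standard RobustBench interface. A minimal sketch, assuming the Google Drive checkpoints listed in the model zoo dictionaries are reachable; the model name is taken verbatim from robustbench/model_zoo/cifar10.py:

    from robustbench.utils import load_model

    # MeanSparse WRN-94-16 registered for CIFAR-10 under the Linf threat model
    model = load_model(model_name='Amini2024MeanSparse_S-WRN-94-16',
                       dataset='cifar10', threat_model='Linf')
    model.eval()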