
Commit

Add parameter type restrictions
rikonaka committed Dec 2, 2023
1 parent 6ba76d4 commit 9762cab
Showing 1 changed file with 48 additions and 42 deletions.
90 changes: 48 additions & 42 deletions torchattacks/attack.py
@@ -1,6 +1,8 @@
import time
from typing import Optional, Union, List, Dict, Tuple
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
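The new `typing` imports support two equivalent spellings that both appear in this change: `Optional[X]` is shorthand for `Union[X, None]`. A minimal sketch, not part of the commit, showing the equivalence:

```python
from typing import Optional, Union

import torch


def scale(x: torch.Tensor, factor: Optional[float] = None) -> torch.Tensor:
    # Optional[float] and Union[float, None] mean the same thing to a type checker.
    return x if factor is None else x * factor


def scale_long_form(x: torch.Tensor, factor: Union[float, None] = None) -> torch.Tensor:
    # Same contract as scale(), written with the longer spelling used later in this diff.
    return x if factor is None else x * factor
```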

@@ -25,7 +27,7 @@ class Attack(object):
To change this, please see `set_model_training_mode`.
"""

def __init__(self, name, model):
def __init__(self, name: str, model: torch.nn.Module) -> None:
r"""
Initializes internal attack state.
@@ -61,34 +63,34 @@ def __init__(self, name, model):
self._batchnorm_training = False
self._dropout_training = False

def forward(self, inputs, labels=None, *args, **kwargs):
def forward(self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None, *args, **kwargs) -> torch.Tensor:
r"""
It defines the computation performed at every call.
Should be overridden by all subclasses.
"""
raise NotImplementedError

@wrapper_method
def set_model(self, model):
def set_model(self, model: torch.nn.Module) -> None:
self.model = model
self.model_name = model.__class__.__name__

def get_logits(self, inputs, labels=None, *args, **kwargs):
def get_logits(self, inputs: torch.Tensor, *args, **kwargs) -> torch.Tensor:
if self._normalization_applied is False:
inputs = self.normalize(inputs)
logits = self.model(inputs)
return logits

@wrapper_method
def _set_normalization_applied(self, flag):
def _set_normalization_applied(self, flag: bool) -> None:
self._normalization_applied = flag

@wrapper_method
def set_device(self, device):
def set_device(self, device: Union[str, torch.device]) -> None:
self.device = device

@wrapper_method
def _set_rmodel_normalization_used(self, model):
def _set_rmodel_normalization_used(self, model: torch.nn.Module) -> None:
r"""
Set attack normalization for MAIR [https://github.com/Harry24k/MAIR].
@@ -104,7 +106,7 @@ def _set_rmodel_normalization_used(self, model):
self.set_normalization_used(mean, std)

@wrapper_method
def set_normalization_used(self, mean, std):
def set_normalization_used(self, mean: Union[np.ndarray, torch.Tensor, List, Tuple], std: Union[np.ndarray, torch.Tensor, List, Tuple]) -> None:
self.normalization_used = {}
n_channels = len(mean)
mean = torch.tensor(mean).reshape(1, n_channels, 1, 1)
@@ -113,25 +115,25 @@ def set_normalization_used(self, mean, std):
self.normalization_used["std"] = std
self._set_normalization_applied(True)

def normalize(self, inputs):
def normalize(self, inputs: torch.Tensor) -> torch.Tensor:
mean = self.normalization_used["mean"].to(inputs.device)
std = self.normalization_used["std"].to(inputs.device)
return (inputs - mean) / std

def inverse_normalize(self, inputs):
def inverse_normalize(self, inputs: torch.Tensor) -> torch.Tensor:
mean = self.normalization_used["mean"].to(inputs.device)
std = self.normalization_used["std"].to(inputs.device)
return inputs * std + mean
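For context, a hedged usage sketch of the normalization helpers annotated above. It assumes the torchattacks `PGD` constructor (`eps`, `alpha`, `steps`) and CIFAR-10-style statistics, and only exercises the `normalize` / `inverse_normalize` round trip; it is not part of the commit.

```python
import torch
import torchattacks

# Tiny stand-in classifier; any nn.Module with parameters works here.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
atk = torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=10)

# CIFAR-10-style statistics (assumed values, purely illustrative).
atk.set_normalization_used(mean=[0.4914, 0.4822, 0.4465],
                           std=[0.2471, 0.2435, 0.2616])

images = torch.rand(4, 3, 32, 32)           # raw [0, 1] batch
normed = atk.normalize(images)              # (x - mean) / std
restored = atk.inverse_normalize(normed)    # x * std + mean
assert torch.allclose(images, restored, atol=1e-5)
```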

def get_mode(self):
def get_mode(self) -> str:
r"""
Get attack mode.
"""
return self.attack_mode

@wrapper_method
def set_mode_default(self):
def set_mode_default(self) -> None:
r"""
Set attack mode as default mode.
@@ -141,7 +143,7 @@ def set_mode_default(self):
print("Attack mode is changed to 'default.'")

@wrapper_method
def _set_mode_targeted(self, mode, quiet):
def _set_mode_targeted(self, mode: str, quiet: bool) -> None:
if "targeted" not in self.supported_mode:
raise ValueError("Targeted mode is not supported.")
self.targeted = True
@@ -150,7 +152,7 @@ def _set_mode_targeted(self, mode, quiet):
print("Attack mode is changed to '%s'." % mode)

@wrapper_method
def set_mode_targeted_by_function(self, target_map_function, quiet=False):
def set_mode_targeted_by_function(self, target_map_function, quiet: bool = False) -> None:
r"""
Set attack mode as targeted.
@@ -165,7 +167,7 @@ def set_mode_targeted_by_function(self, target_map_function, quiet=False):
self._target_map_function = target_map_function

@wrapper_method
def set_mode_targeted_random(self, quiet=False):
def set_mode_targeted_random(self, quiet: bool = False) -> None:
r"""
Set attack mode as targeted with random labels.
@@ -177,13 +179,12 @@ def set_mode_targeted_random(self, quiet=False):
self._target_map_function = self.get_random_target_label

@wrapper_method
def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):
def set_mode_targeted_least_likely(self, kth_min: int = 1, quiet: bool = False) -> None:
r"""
Set attack mode as targeted with least likely labels.
Arguments:
kth_min (str): label with the k-th smallest probability used as target labels. (Default: 1)
num_classses (str): number of classes. (Default: False)
kth_min (int): label with the k-th smallest probability used as target labels. (Default: 1)
"""
self._set_mode_targeted("targeted(least-likely)", quiet)
@@ -192,7 +193,7 @@ def set_mode_targeted_least_likely(self, kth_min=1, quiet=False):
self._target_map_function = self.get_least_likely_label
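A hedged sketch of the targeted-mode setters typed above, again using `PGD` as a stand-in subclass; the "next class" mapping and the 10-class setup are illustrative assumptions, not part of the commit.

```python
import torch
import torchattacks

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
atk = torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=10)

# Option 1: map ground-truth labels to targets (here: the next class, cyclically).
atk.set_mode_targeted_by_function(
    target_map_function=lambda inputs, labels: (labels + 1) % 10
)

# Option 2 (alternative): aim for the model's least-likely class instead.
# atk.set_mode_targeted_least_likely(kth_min=1)

images = torch.rand(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))
adv_images = atk(images, labels)   # perturbs towards the mapped target labels
print(atk.get_mode())              # reports the current targeted-mode string
```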

@wrapper_method
def set_mode_targeted_by_label(self, quiet=False):
def set_mode_targeted_by_label(self, quiet: bool = False) -> None:
r"""
Set attack mode as targeted.
@@ -207,8 +208,11 @@ def set_mode_targeted_by_label(self, quiet=False):

@wrapper_method
def set_model_training_mode(
self, model_training=False, batchnorm_training=False, dropout_training=False
):
self,
model_training: bool = False,
batchnorm_training: bool = False,
dropout_training: bool = False
) -> None:
r"""
Set training mode during attack process.
@@ -247,13 +251,13 @@ def _recover_model_mode(self, given_training):
def save(
self,
data_loader,
save_path=None,
verbose=True,
return_verbose=False,
save_predictions=False,
save_clean_inputs=False,
save_path: Union[str, None] = None,
verbose: bool = True,
return_verbose: bool = False,
save_predictions: bool = False,
save_clean_inputs: bool = False,
save_type="float",
):
) -> Tuple[float, torch.Tensor, float]:
r"""
Save adversarial inputs as torch.tensor from given torch.utils.data.DataLoader.
@@ -370,7 +374,7 @@ def save(
return rob_acc, l2, elapsed_time
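A hedged sketch of `save` with the keyword arguments annotated above; the model, dataset, file name, and batch size are placeholders, not values from the commit.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torchattacks

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
atk = torchattacks.PGD(model, eps=8 / 255, alpha=2 / 255, steps=10)

# Placeholder data standing in for a real [0, 1]-ranged test set.
data = TensorDataset(torch.rand(32, 3, 32, 32), torch.randint(0, 10, (32,)))
loader = DataLoader(data, batch_size=16)

# With return_verbose=True the call also returns the
# (robust accuracy, L2 distance, elapsed time) values it prints.
rob_acc, l2, elapsed = atk.save(
    loader,
    save_path="adv_data.pt",   # placeholder path
    verbose=True,
    return_verbose=True,
    save_type="float",
)
```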

@staticmethod
def to_type(inputs, type):
def to_type(inputs: torch.Tensor, type: str) -> torch.Tensor:
r"""
Return inputs as int if float is given.
"""
@@ -385,11 +389,12 @@ def to_type(inputs, type):
):
return inputs.float() / 255
else:
raise ValueError(type + " is not a valid type. [Options: float, int]")
raise ValueError(
type + " is not a valid type. [Options: float, int]")
return inputs

@staticmethod
def _save_print(progress, rob_acc, l2, elapsed_time, end):
def _save_print(progress: float, rob_acc: float, l2: float, elapsed_time: float, end: Union[str, None]) -> None:
print(
"- Save progress: %2.2f %% / Robust accuracy: %2.2f %% / L2: %1.5f (%2.3f it/s) \t"
% (progress, rob_acc, l2, elapsed_time),
@@ -398,13 +403,13 @@ def _save_print(progress, rob_acc, l2, elapsed_time, end):

@staticmethod
def load(
load_path,
batch_size=128,
shuffle=False,
load_path: str,
batch_size: int = 128,
shuffle: bool = False,
normalize=None,
load_predictions=False,
load_clean_inputs=False,
):
load_predictions: bool = False,
load_clean_inputs: bool = False,
) -> DataLoader:
save_dict = torch.load(load_path)
keys = ["adv_inputs", "labels"]

@@ -431,14 +436,15 @@ def load(
) / std # nopep8

adv_data = TensorDataset(*[save_dict[key] for key in keys])
adv_loader = DataLoader(adv_data, batch_size=batch_size, shuffle=shuffle)
adv_loader = DataLoader(
adv_data, batch_size=batch_size, shuffle=shuffle)
print(
"Data is loaded in the following order: [%s]" % (", ".join(keys))
) # nopep8
return adv_loader
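And a matching sketch for the `load` counterpart, assuming the placeholder file written by the `save` sketch above; `load` is a staticmethod, so it can be called on the class directly.

```python
from torchattacks.attack import Attack

# Rebuild a DataLoader over the saved adversarial examples and labels.
adv_loader = Attack.load(
    "adv_data.pt",     # placeholder file produced by the save(...) sketch above
    batch_size=16,
    shuffle=False,
)

for adv_images, labels in adv_loader:
    print(adv_images.shape, labels.shape)
    break
```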

@torch.no_grad()
def get_output_with_eval_nograd(self, inputs):
def get_output_with_eval_nograd(self, inputs: torch.Tensor) -> torch.Tensor:
given_training = self.model.training
if given_training:
self.model.eval()
@@ -447,7 +453,7 @@ def get_output_with_eval_nograd(self, inputs):
self.model.train()
return outputs

def get_target_label(self, inputs, labels=None):
def get_target_label(self, inputs: torch.Tensor, labels: Union[torch.Tensor, None] = None) -> torch.Tensor:
r"""
Function for changing the attack mode.
Return input labels.
@@ -463,7 +469,7 @@ def get_target_label(self, inputs, labels=None):
return target_labels

@torch.no_grad()
def get_least_likely_label(self, inputs, labels=None):
def get_least_likely_label(self, inputs: torch.Tensor, labels: Union[torch.Tensor, None] = None) -> torch.Tensor:
outputs = self.get_output_with_eval_nograd(inputs)
if labels is None:
_, labels = torch.max(outputs, dim=1)
@@ -479,7 +485,7 @@ def get_least_likely_label(self, inputs, labels=None):
return target_labels.long().to(self.device)

@torch.no_grad()
def get_random_target_label(self, inputs, labels=None):
def get_random_target_label(self, inputs: torch.Tensor, labels: Union[torch.Tensor, None] = None) -> torch.Tensor:
outputs = self.get_output_with_eval_nograd(inputs)
if labels is None:
_, labels = torch.max(outputs, dim=1)
@@ -494,7 +500,7 @@ def get_random_target_label(self, inputs, labels=None):

return target_labels.long().to(self.device)

def __call__(self, inputs, labels=None, *args, **kwargs):
def __call__(self, inputs: torch.Tensor, labels: Union[torch.Tensor, None] = None, *args, **kwargs) -> torch.Tensor:
given_training = self.model.training
self._change_model_mode(given_training)

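Finally, a hedged end-to-end sketch of `__call__` with the new annotations, using `FGSM` as a stand-in subclass; the tiny classifier and batch sizes are illustrative assumptions.

```python
import torch
import torchattacks

# Tiny stand-in classifier; any nn.Module returning class logits would do.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
atk = torchattacks.FGSM(model, eps=8 / 255)

images = torch.rand(16, 3, 32, 32)
labels = torch.randint(0, 10, (16,))

# __call__ wraps forward(): it switches the model mode as configured,
# handles normalization if one was registered, and returns the adversarial batch.
adv_images = atk(images, labels)
print(adv_images.shape, float(adv_images.min()), float(adv_images.max()))
```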
