Skip to content

Commit

Permalink
Merge pull request #2207 from Trusted-AI/dev_1.15.0
Browse files Browse the repository at this point in the history
Update to ART 1.15.0
  • Loading branch information
beat-buesser authored Jun 30, 2023
2 parents 7691b39 + e3708eb commit aaa5bce
Show file tree
Hide file tree
Showing 68 changed files with 5,578 additions and 1,065 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci-pytorch-object-detectors.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_object_detector.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_faster_rcnn
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_faster_rcnn.py --framework=pytorch --durations=0
- name: Run Test Action - test_pytorch_detection_transformer
run: pytest --cov-report=xml --cov=art --cov-append -q -vv tests/estimators/object_detection/test_pytorch_detection_transformer.py --framework=pytorch --durations=0
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
with:
Expand Down
4 changes: 2 additions & 2 deletions art/attacks/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,10 @@ def __init__(self):
@abc.abstractmethod
def poison(
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples and return them as an array. This method should be overridden by all concrete
poisoning attack implementations.
Expand Down
32 changes: 23 additions & 9 deletions art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,9 +575,9 @@ def __getitem__(self, idx):
img = torch.from_numpy(self.x[idx])

target = {}
target["boxes"] = torch.from_numpy(y[idx]["boxes"])
target["labels"] = torch.from_numpy(y[idx]["labels"])
target["scores"] = torch.from_numpy(y[idx]["scores"])
target["boxes"] = torch.from_numpy(self.y[idx]["boxes"])
target["labels"] = torch.from_numpy(self.y[idx]["labels"])
target["scores"] = torch.from_numpy(self.y[idx]["scores"])
mask_i = torch.from_numpy(self.mask[idx])

return img, target, mask_i
Expand All @@ -602,19 +602,33 @@ def __getitem__(self, idx):
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
targets = []
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
target = targets
_ = self._train_step(images=images, target=target, mask=None)
else:
for images, target, mask_i in data_loader:
images = images.to(self.estimator.device)
if isinstance(target, torch.Tensor):
target = target.to(self.estimator.device)
else:
target["boxes"] = target["boxes"].to(self.estimator.device)
target["labels"] = target["labels"].to(self.estimator.device)
target["scores"] = target["scores"].to(self.estimator.device)
targets = []
for idx in range(target["boxes"].shape[0]):
targets.append(
{
"boxes": target["boxes"][idx].to(self.estimator.device),
"labels": target["labels"][idx].to(self.estimator.device),
"scores": target["scores"][idx].to(self.estimator.device),
}
)
target = targets
mask_i = mask_i.to(self.estimator.device)
_ = self._train_step(images=images, target=target, mask=mask_i)

Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/auto_conjugate_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,8 @@ def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, *args, **kwargs) -> tf.
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=_loss_object_tf,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/auto_projected_gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, *args, **kwargs) -> tf.
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=_loss_object_tf,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
3 changes: 2 additions & 1 deletion art/attacks/evasion/brendel_bethge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2055,7 +2055,8 @@ def logits_difference(y_true, y_pred):
nb_classes=estimator.nb_classes,
input_shape=estimator.input_shape,
loss_object=self._loss_object,
train_step=estimator._train_step,
optimizer=estimator.optimizer,
train_step=estimator.train_step,
channels_first=estimator.channels_first,
clip_values=estimator.clip_values,
preprocessing_defences=estimator.preprocessing_defences,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _generate_batch(
inputs = x.to(self.estimator.device)
targets = targets.to(self.estimator.device)
adv_x = torch.clone(inputs)
momentum = torch.zeros(inputs.shape)
momentum = torch.zeros(inputs.shape).to(self.estimator.device)

if mask is not None:
mask = mask.to(self.estimator.device)
Expand Down
16 changes: 8 additions & 8 deletions art/attacks/evasion/sign_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,14 +306,14 @@ def _fine_grained_binary_search_local(
lbd = initial_lbd
# For targeted: we want to expand(x1.01) boundary away from targeted dataset
# For untargeted, we want to slim(x0.99) the boundary toward the original dataset
if (not self._is_label(x_0 + lbd * theta, target) and self.targeted) or (
self._is_label(x_0 + lbd * theta, y_0) and not self.targeted
if (self.targeted and not self._is_label(x_0 + lbd * theta, target)) or (
not self.targeted and self._is_label(x_0 + lbd * theta, y_0)
):
lbd_lo = lbd
lbd_hi = lbd * 1.01
nquery += 1
while (not self._is_label(x_0 + lbd_hi * theta, target) and self.targeted) or (
self._is_label(x_0 + lbd_hi * theta, y_0) and not self.targeted
while (self.targeted and not self._is_label(x_0 + lbd_hi * theta, target)) or (
not self.targeted and self._is_label(x_0 + lbd_hi * theta, y_0)
):
lbd_hi = lbd_hi * 1.01
nquery += 1
Expand All @@ -323,17 +323,17 @@ def _fine_grained_binary_search_local(
lbd_hi = lbd
lbd_lo = lbd * 0.99
nquery += 1
while (self._is_label(x_0 + lbd_lo * theta, target) and self.targeted) or (
not self._is_label(x_0 + lbd_lo * theta, y_0) and not self.targeted
while (self.targeted and self._is_label(x_0 + lbd_lo * theta, target)) or (
not self.targeted and not self._is_label(x_0 + lbd_lo * theta, y_0)
):
lbd_lo = lbd_lo * 0.99
nquery += 1

while (lbd_hi - lbd_lo) > tol:
lbd_mid = (lbd_lo + lbd_hi) / 2.0
nquery += 1
if (self._is_label(x_0 + lbd_mid * theta, target) and self.targeted) or (
not self._is_label(x_0 + lbd_mid * theta, y_0) and not self.targeted
if (self.targeted and self._is_label(x_0 + lbd_mid * theta, target)) or (
not self.targeted and not self._is_label(x_0 + lbd_mid * theta, y_0)
):
lbd_hi = lbd_mid
else:
Expand Down
41 changes: 24 additions & 17 deletions art/attacks/poisoning/bad_det/bad_det_gma.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -77,36 +77,39 @@ def __init__(

def poison( # pylint: disable=W0221
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
for labels `y`.
:param x: Sample images of shape `NCHW` or `NHWC`.
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
of the dictionary are:
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
- labels [N]: the labels for each image.
- scores [N]: the scores or each prediction.
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
x_ndim = len(x.shape)
if isinstance(x, np.ndarray):
x_ndim = len(x.shape)
else:
x_ndim = len(x[0].shape) + 1

if x_ndim != 4:
raise ValueError("Unrecognized input dimension. BadDet GMA can only be applied to image data.")

if self.channels_first:
# NCHW --> NHWC
x = np.transpose(x, (0, 2, 3, 1))

x_poison = x.copy()
y_poison: List[Dict[str, np.ndarray]] = []
# copy images
x_poison: Union[np.ndarray, List[np.ndarray]]
if isinstance(x, np.ndarray):
x_poison = x.copy()
else:
x_poison = [x_i.copy() for x_i in x]

# copy labels
y_poison: List[Dict[str, np.ndarray]] = []
for y_i in y:
target_dict = {k: v.copy() for k, v in y_i.items()}
y_poison.append(target_dict)
Expand All @@ -120,18 +123,22 @@ def poison( # pylint: disable=W0221
image = x_poison[i]
labels = y_poison[i]["labels"]

if self.channels_first:
image = np.transpose(image, (1, 2, 0))

# insert backdoor into the image
# add an additional dimension to create a batch of size 1
poisoned_input, _ = self.backdoor.poison(image[np.newaxis], labels)
x_poison[i] = poisoned_input[0]
image = poisoned_input[0]

# replace the original image with the poisoned image
if self.channels_first:
image = np.transpose(image, (2, 0, 1))
x_poison[i] = image

# change all labels to the target label
y_poison[i]["labels"] = np.full(labels.shape, self.class_target)

if self.channels_first:
# NHWC --> NCHW
x_poison = np.transpose(x_poison, (0, 3, 1, 2))

return x_poison, y_poison

def _check_params(self) -> None:
Expand Down
40 changes: 23 additions & 17 deletions art/attacks/poisoning/bad_det/bad_det_oda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
from typing import Dict, List, Tuple
from typing import Dict, List, Tuple, Union

import numpy as np
from tqdm.auto import tqdm
Expand Down Expand Up @@ -77,36 +77,39 @@ def __init__(

def poison( # pylint: disable=W0221
self,
x: np.ndarray,
x: Union[np.ndarray, List[np.ndarray]],
y: List[Dict[str, np.ndarray]],
**kwargs,
) -> Tuple[np.ndarray, List[Dict[str, np.ndarray]]]:
) -> Tuple[Union[np.ndarray, List[np.ndarray]], List[Dict[str, np.ndarray]]]:
"""
Generate poisoning examples by inserting the backdoor onto the input `x` and changing the classification
for labels `y`.
:param x: Sample images of shape `NCHW` or `NHWC`.
:param x: Sample images of shape `NCHW` or `NHWC` or a list of sample images of any size.
:param y: True labels of type `List[Dict[np.ndarray]]`, one dictionary per input image. The keys and values
of the dictionary are:
- boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
- labels [N]: the labels for each image.
- scores [N]: the scores or each prediction.
:return: An tuple holding the `(poisoning_examples, poisoning_labels)`.
"""
x_ndim = len(x.shape)
if isinstance(x, np.ndarray):
x_ndim = len(x.shape)
else:
x_ndim = len(x[0].shape) + 1

if x_ndim != 4:
raise ValueError("Unrecognized input dimension. BadDet ODA can only be applied to image data.")

if self.channels_first:
# NCHW --> NHWC
x = np.transpose(x, (0, 2, 3, 1))

x_poison = x.copy()
y_poison: List[Dict[str, np.ndarray]] = []
# copy images
x_poison: Union[np.ndarray, List[np.ndarray]]
if isinstance(x, np.ndarray):
x_poison = x.copy()
else:
x_poison = [x_i.copy() for x_i in x]

# copy labels and find indices of the source class
y_poison: List[Dict[str, np.ndarray]] = []
source_indices = []
for i, y_i in enumerate(y):
target_dict = {k: v.copy() for k, v in y_i.items()}
Expand All @@ -121,10 +124,12 @@ def poison( # pylint: disable=W0221

for i in tqdm(selected_indices, desc="BadDet ODA iteration", disable=not self.verbose):
image = x_poison[i]

boxes = y_poison[i]["boxes"]
labels = y_poison[i]["labels"]

if self.channels_first:
image = np.transpose(image, (1, 2, 0))

keep_indices = []

for j, (box, label) in enumerate(zip(boxes, labels)):
Expand All @@ -140,13 +145,14 @@ def poison( # pylint: disable=W0221
else:
keep_indices.append(j)

# replace the original image with the poisoned image
if self.channels_first:
image = np.transpose(image, (2, 0, 1))
x_poison[i] = image

# remove labels for poisoned bounding boxes
y_poison[i] = {k: v[keep_indices] for k, v in y_poison[i].items()}

if self.channels_first:
# NHWC --> NCHW
x_poison = np.transpose(x_poison, (0, 3, 1, 2))

return x_poison, y_poison

def _check_params(self) -> None:
Expand Down
Loading

0 comments on commit aaa5bce

Please sign in to comment.