diff --git a/src/algos/base_class.py b/src/algos/base_class.py
index 7b61baf..72f1344 100644
--- a/src/algos/base_class.py
+++ b/src/algos/base_class.py
@@ -313,6 +313,7 @@ def get_malicious_model_weights(self) -> Dict[str, Tensor]:
         if self.malicious_type == "sign_flip":
             return SignFlipAttack(self.config, self.model.state_dict()).get_representation()
         elif self.malicious_type == "bad_weights":
+            print("bad weights attack")
             return BadWeightsAttack(self.config, self.model.state_dict()).get_representation()
         elif self.malicious_type == "add_noise":
             return AddNoiseAttack(self.config, self.model.state_dict()).get_representation()
diff --git a/src/utils/model_utils.py b/src/utils/model_utils.py
index 97c639e..82011a1 100644
--- a/src/utils/model_utils.py
+++ b/src/utils/model_utils.py
@@ -88,7 +88,7 @@ def train(
                 model, optim, dloader, loss_fn, device, test_loader=None, **kwargs
             )
             return mean_loss, acc
-        elif self.malicious_type == "backdoor_attack" or self.malicious_type == "gradient_attack":
+        elif self.malicious_type == "backdoor_attack" or self.malicious_type == "gradient_attack" or self.malicious_type == "label_flip":
            train_loss, acc = self.train_classification_malicious(
                 model, optim, dloader, loss_fn, device, test_loader=None, **kwargs
             )
@@ -308,6 +308,7 @@ def train_classification_malicious(
             loss.backward()
         elif self.malicious_type == "label_flip":
             # permutation = torch.tensor(self.config.get("permutation", [i for i in range(10)]))
+            print("flipping labels")
             permute_labels = self.config.get("permute_labels", 10)
             permutation = torch.randperm(permute_labels)
             permutation = permutation.to(target.device)
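
The third hunk is truncated right after the permutation is moved to the target's device, so the actual remapping step is not visible here. Below is a minimal sketch of how such a randperm-based label-flip is typically applied; the helper name flip_labels and the indexing step permutation[target] are assumptions for illustration, not code taken from the repository.

import torch

def flip_labels(target: torch.Tensor, num_classes: int = 10) -> torch.Tensor:
    # Draw a random permutation of the class indices and remap every label
    # through it, so class i is consistently relabeled as permutation[i].
    # (Assumed application step; the diff above cuts off before this point.)
    permutation = torch.randperm(num_classes).to(target.device)
    return permutation[target]

# Usage: remap a batch of 10-class labels before the loss is computed.
labels = torch.randint(0, 10, (32,))
flipped = flip_labels(labels, num_classes=10)

Because the permutation is redrawn per call, the mapping changes every batch; reading a fixed permutation from the config (as the commented-out line in the hunk suggests) would instead give a consistent class-to-class flip across the whole training run.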