diff --git a/src/algos/MetaL2C.py b/src/algos/MetaL2C.py
index beb267d..f0e0edf 100644
--- a/src/algos/MetaL2C.py
+++ b/src/algos/MetaL2C.py
@@ -122,7 +122,7 @@ def __init__(
     ) -> None:
         super().__init__(config, comm_utils)

-        self.encoder = ModelEncoder(self.get_model_weights())
+        self.encoder = ModelEncoder(self.get_model_weights(get_external_repr=False))
         self.encoder_optim = optim.SGD(
             self.encoder.parameters(), lr=self.config["alpha_lr"]
         )
diff --git a/src/algos/base_class.py b/src/algos/base_class.py
index e967ade..7b61baf 100644
--- a/src/algos/base_class.py
+++ b/src/algos/base_class.py
@@ -29,6 +29,12 @@
     TransformDataset,
     CorruptDataset,
 )
+
+# Import the available malicious attacks
+from algos.attack_add_noise import AddNoiseAttack
+from algos.attack_bad_weights import BadWeightsAttack
+from algos.attack_sign_flip import SignFlipAttack
+
 from utils.log_utils import LogUtils
 from utils.model_utils import ModelUtils
 from utils.community_utils import (
@@ -94,6 +100,7 @@ class BaseNode(ABC):
     def __init__(
         self, config: Dict[str, Any], comm_utils: CommunicationManager
     ) -> None:
+        self.config = config
         self.set_constants()
         self.comm_utils = comm_utils
         self.node_id = self.comm_utils.get_rank()
@@ -122,6 +129,7 @@ def __init__(
         if "gia" in config and self.node_id in config["gia_attackers"]:
             self.gia_attacker = True

+        self.malicious_type = config.get("malicious_type", "normal")
         self.log_memory = config.get("log_memory", False)

@@ -175,7 +183,7 @@ def setup_logging(self, config: Dict[str, ConfigType]) -> None:
     def setup_cuda(self, config: Dict[str, ConfigType]) -> None:
         """add docstring here"""
         # Need a mapping from rank to device id
-        if (config.get("assign_based_on_host", False)):
+        if not config.get("assign_based_on_host", False):
             device_ids_map = config["device_ids"]
             node_name = f"node_{self.node_id}"
             self.device_ids = device_ids_map[node_name]
@@ -271,11 +279,21 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
     def local_round_done(self) -> None:
         self.round += 1

-    def get_model_weights(self) -> Dict[str, Tensor]:
+    def get_model_weights(self, get_external_repr: bool = True) -> Dict[str, Tensor]:
         """
         Share the model weights
+
+        Args:
+            get_external_repr (bool): Whether to return the external representation of the
+                model; a malicious node routes its weights through the configured attack.
         """
-        message = {"sender": self.node_id, "round": self.round, "model": self.model.state_dict()}
+
+        if get_external_repr and self.malicious_type != "normal":
+            # Get the external representation of the malicious model
+            model = self.get_malicious_model_weights()
+        else:
+            model = self.model.state_dict()
+        message = {"sender": self.node_id, "round": self.round, "model": model}

         if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
             # also stream image and labels
@@ -287,6 +305,19 @@ def get_model_weights(self) -> Dict[str, Tensor]:
             message["model"][key] = message["model"][key].to("cpu")

         return message
+
+    def get_malicious_model_weights(self) -> Dict[str, Tensor]:
+        """
+        Return the external representation of the model based on the configured malicious type.
+ """ + if self.malicious_type == "sign_flip": + return SignFlipAttack(self.config, self.model.state_dict()).get_representation() + elif self.malicious_type == "bad_weights": + return BadWeightsAttack(self.config, self.model.state_dict()).get_representation() + elif self.malicious_type == "add_noise": + return AddNoiseAttack(self.config, self.model.state_dict()).get_representation() + else: + return self.model.state_dict() def get_local_rounds(self) -> int: return self.round @@ -1101,7 +1132,7 @@ def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None: total_weight = 0.0 # To re-normalize weights after handling dropouts # Include the current node's model in the aggregation - current_model_wts = self.get_model_weights() + current_model_wts = self.get_model_weights(get_external_repr=False) # internal model representation assert "model" in current_model_wts, "Model not found in the current model." current_model_wts = current_model_wts["model"] current_weight = 1.0 / (len(neighbors) + 1) # Weight for the current node diff --git a/src/algos/fl.py b/src/algos/fl.py index 7d5e58a..c4b6eb7 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -37,52 +37,52 @@ def local_test(self, **kwargs: Any) -> Tuple[float, float, float]: return test_loss, test_acc, time_taken - def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]: - """ - Overwrite the get_model_weights method of the BaseNode - to add malicious attacks - TODO: this should be moved to BaseClient - """ - - message = {"sender": self.node_id, "round": self.round} - - malicious_type = self.config.get("malicious_type", "normal") - - if malicious_type == "normal": - message["model"] = self.model.state_dict() # type: ignore - elif malicious_type == "bad_weights": - # Corrupt the weights - message["model"] = BadWeightsAttack( - self.config, self.model.state_dict() - ).get_representation() - elif malicious_type == "sign_flip": - # Flip the sign of the weights, also TODO: consider label flipping - message["model"] = SignFlipAttack( - self.config, self.model.state_dict() - ).get_representation() - elif malicious_type == "add_noise": - # Add noise to the weights - message["model"] = AddNoiseAttack( - self.config, self.model.state_dict() - ).get_representation() - else: - message["model"] = self.model.state_dict() # type: ignore - - # move the model to cpu before sending - for key in message["model"].keys(): - message["model"][key] = message["model"][key].to("cpu") - - # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found" - if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'): - # also stream image and labels - message["images"] = self.images.to("cpu") - message["labels"] = self.labels.to("cpu") - - message["random_params"] = self.random_params - for key in message["random_params"].keys(): - message["random_params"][key] = message["random_params"][key].to("cpu") + # def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]: + # """ + # Overwrite the get_model_weights method of the BaseNode + # to add malicious attacks + # TODO: this should be moved to BaseClient + # """ + + # message = {"sender": self.node_id, "round": self.round} + + # malicious_type = self.config.get("malicious_type", "normal") + + # if malicious_type == "normal": + # message["model"] = self.model.state_dict() # type: ignore + # elif malicious_type == "bad_weights": + # # Corrupt the weights + # message["model"] = BadWeightsAttack( + # self.config, self.model.state_dict() + # ).get_representation() 
+    #     elif malicious_type == "sign_flip":
+    #         # Flip the sign of the weights, also TODO: consider label flipping
+    #         message["model"] = SignFlipAttack(
+    #             self.config, self.model.state_dict()
+    #         ).get_representation()
+    #     elif malicious_type == "add_noise":
+    #         # Add noise to the weights
+    #         message["model"] = AddNoiseAttack(
+    #             self.config, self.model.state_dict()
+    #         ).get_representation()
+    #     else:
+    #         message["model"] = self.model.state_dict()  # type: ignore
+
+    #     # move the model to cpu before sending
+    #     for key in message["model"].keys():
+    #         message["model"][key] = message["model"][key].to("cpu")
+
+    #     # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
+    #     if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
+    #         # also stream image and labels
+    #         message["images"] = self.images.to("cpu")
+    #         message["labels"] = self.labels.to("cpu")
+
+    #     message["random_params"] = self.random_params
+    #     for key in message["random_params"].keys():
+    #         message["random_params"][key] = message["random_params"][key].to("cpu")

-        return message  # type: ignore
+    #     return message  # type: ignore

     def run_protocol(self):
         print(f"Client {self.node_id} ready to start training")
diff --git a/src/configs/malicious_config.py b/src/configs/malicious_config.py
index 15025d4..ed9ecb7 100644
--- a/src/configs/malicious_config.py
+++ b/src/configs/malicious_config.py
@@ -60,4 +60,5 @@
     "gradient_attack": gradient_attack,
     "backdoor_attack": backdoor_attack,
     "data_poisoning": data_poisoning,
+    "label_flip": label_flip,
 }
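Reviewer note: the attack classes imported at the top of base_class.py (src/algos/attack_sign_flip.py and friends) are not part of this diff, so the dispatch in get_malicious_model_weights only shows their call sites. For context, a minimal sketch of the interface that dispatch assumes, namely a (config, state_dict) constructor plus a get_representation() method; the flip_ratio knob is hypothetical and the real implementation may differ.

import random
from typing import Any, Dict

from torch import Tensor


class SignFlipAttack:
    """Sketch of the contract assumed by get_malicious_model_weights:
    constructed from (config, state_dict), queried via get_representation()."""

    def __init__(self, config: Dict[str, Any], state_dict: Dict[str, Tensor]) -> None:
        self.state_dict = state_dict
        # Hypothetical knob: probability of flipping each tensor.
        self.flip_ratio = float(config.get("flip_ratio", 1.0))

    def get_representation(self) -> Dict[str, Tensor]:
        # Negate tensors so the shared update points away from the true model.
        return {
            key: -val if random.random() < self.flip_ratio else val
            for key, val in self.state_dict.items()
        }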
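The behavioural contract of the new get_external_repr flag, end to end: the default call returns the external (possibly attacked) representation, while get_external_repr=False returns the node's true local weights, which is why MetaL2C's encoder and the node's own contribution in receive_and_aggregate_streaming now pass False. A self-contained sketch; _DemoNode is a hypothetical stand-in, not a class from this repo.

from typing import Any, Dict

import torch


class _DemoNode:
    """Hypothetical stand-in for a BaseNode subclass, condensed from the patch."""

    def __init__(self, malicious_type: str = "sign_flip") -> None:
        self.node_id, self.round = 1, 0
        self.malicious_type = malicious_type
        self.model = torch.nn.Linear(4, 2)

    def get_model_weights(self, get_external_repr: bool = True) -> Dict[str, Any]:
        # Mirrors BaseNode.get_model_weights: only the shared (external)
        # representation is routed through the attack.
        if get_external_repr and self.malicious_type != "normal":
            model = self.get_malicious_model_weights()
        else:
            model = self.model.state_dict()
        return {"sender": self.node_id, "round": self.round, "model": model}

    def get_malicious_model_weights(self) -> Dict[str, torch.Tensor]:
        # Stand-in for the attack dispatch; here, a plain sign flip.
        return {k: -v for k, v in self.model.state_dict().items()}


node = _DemoNode()
shared = node.get_model_weights()                         # what neighbors see
local = node.get_model_weights(get_external_repr=False)   # what the node itself uses
assert torch.equal(shared["model"]["weight"], -local["model"]["weight"])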