Commit a5e5d74
abstracted get_model_weights to work with malicious weight sending
joyce-yuan committed Dec 17, 2024
1 parent 1a2c089 commit a5e5d74
Showing 4 changed files with 82 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/algos/MetaL2C.py
@@ -122,7 +122,7 @@ def __init__(
     ) -> None:
         super().__init__(config, comm_utils)

-        self.encoder = ModelEncoder(self.get_model_weights())
+        self.encoder = ModelEncoder(self.get_model_weights(get_external_repr=False))
         self.encoder_optim = optim.SGD(
             self.encoder.parameters(), lr=self.config["alpha_lr"]
         )
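A minimal runnable sketch of the contract behind the new get_external_repr flag (illustrative only; the weights and the attack stand-in below are not repo code): passing get_external_repr=False must always return the node's true local weights, which is why MetaL2C builds its encoder this way.

# Runnable sketch (illustrative only): true weights vs. the external,
# possibly attack-modified representation.
import torch

true_weights = {"fc.weight": torch.ones(2, 2)}
attacked = {k: v + 0.5 for k, v in true_weights.items()}  # stand-in for e.g. add_noise

def get_model_weights(get_external_repr: bool = True):
    # Mirrors BaseNode.get_model_weights for a node whose malicious_type != "normal".
    return {"model": attacked if get_external_repr else true_weights}

# MetaL2C-style call: the encoder must see the honest weights.
assert torch.equal(
    get_model_weights(get_external_repr=False)["model"]["fc.weight"],
    true_weights["fc.weight"],
)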
39 changes: 35 additions & 4 deletions src/algos/base_class.py
@@ -29,6 +29,12 @@
     TransformDataset,
     CorruptDataset,
 )
+
+# import the possible attacks
+from algos.attack_add_noise import AddNoiseAttack
+from algos.attack_bad_weights import BadWeightsAttack
+from algos.attack_sign_flip import SignFlipAttack
+
 from utils.log_utils import LogUtils
 from utils.model_utils import ModelUtils
 from utils.community_utils import (
@@ -94,6 +100,7 @@ class BaseNode(ABC):
     def __init__(
         self, config: Dict[str, Any], comm_utils: CommunicationManager
     ) -> None:
+        self.config = config
         self.set_constants()
         self.comm_utils = comm_utils
         self.node_id = self.comm_utils.get_rank()
@@ -122,6 +129,7 @@ def __init__(
 
         if "gia" in config and self.node_id in config["gia_attackers"]:
             self.gia_attacker = True
+        self.malicious_type = config.get("malicious_type", "normal")
 
         self.log_memory = config.get("log_memory", False)
 
@@ -175,7 +183,7 @@ def setup_logging(self, config: Dict[str, ConfigType]) -> None:
     def setup_cuda(self, config: Dict[str, ConfigType]) -> None:
         """add docstring here"""
         # Need a mapping from rank to device id
-        if (config.get("assign_based_on_host", False)):
+        if (config.get("assign_based_on_host", False)) == False:
            device_ids_map = config["device_ids"]
            node_name = f"node_{self.node_id}"
            self.device_ids = device_ids_map[node_name]
@@ -271,11 +279,21 @@ def set_shared_exp_parameters(self, config: Dict[str, ConfigType]) -> None:
     def local_round_done(self) -> None:
         self.round += 1
 
-    def get_model_weights(self) -> Dict[str, Tensor]:
+    def get_model_weights(self, get_external_repr: bool = True) -> Dict[str, Tensor]:
         """
         Share the model weights
+        Args:
+            get_external_repr (bool): Whether to get the external representation of the model,
+                used for malicious attacks where the model weights are modified before sharing.
         """
-        message = {"sender": self.node_id, "round": self.round, "model": self.model.state_dict()}
+        if get_external_repr and self.malicious_type != "normal":
+            # Get the external representation of the malicious model
+            model = self.get_malicious_model_weights()
+        else:
+            model = self.model.state_dict()
+        message = {"sender": self.node_id, "round": self.round, "model": model}
 
         if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
             # also stream image and labels
@@ -287,6 +305,19 @@ def get_model_weights(self) -> Dict[str, Tensor]:
             message["model"][key] = message["model"][key].to("cpu")
 
         return message
+
+    def get_malicious_model_weights(self) -> Dict[str, Tensor]:
+        """
+        Get the external representation of the model based on the malicious type.
+        """
+        if self.malicious_type == "sign_flip":
+            return SignFlipAttack(self.config, self.model.state_dict()).get_representation()
+        elif self.malicious_type == "bad_weights":
+            return BadWeightsAttack(self.config, self.model.state_dict()).get_representation()
+        elif self.malicious_type == "add_noise":
+            return AddNoiseAttack(self.config, self.model.state_dict()).get_representation()
+        else:
+            return self.model.state_dict()
 
     def get_local_rounds(self) -> int:
         return self.round
@@ -1101,7 +1132,7 @@ def receive_and_aggregate_streaming(self, neighbors: List[int]) -> None:
         total_weight = 0.0  # To re-normalize weights after handling dropouts
 
         # Include the current node's model in the aggregation
-        current_model_wts = self.get_model_weights()
+        current_model_wts = self.get_model_weights(get_external_repr=False)  # internal model representation
         assert "model" in current_model_wts, "Model not found in the current model."
         current_model_wts = current_model_wts["model"]
         current_weight = 1.0 / (len(neighbors) + 1)  # Weight for the current node
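The dispatch added to BaseNode can be sketched standalone. The snippet below is illustrative, not repo code: SignFlipSketch is a hypothetical stand-in for algos.attack_sign_flip.SignFlipAttack, assuming only the (config, state_dict) constructor and get_representation() interface visible in the diff.

# Standalone sketch of the new dispatch (illustrative, not repo code).
import torch

class SignFlipSketch:
    """Hypothetical stand-in for SignFlipAttack; interface assumed from the diff."""
    def __init__(self, config, state_dict):
        self.state_dict = state_dict

    def get_representation(self):
        # Flip the sign of every tensor in the state_dict.
        return {k: -v for k, v in self.state_dict.items()}

weights = {"layer.weight": torch.ones(2, 2)}
malicious_type = "sign_flip"  # would come from config.get("malicious_type", "normal")

external = (
    SignFlipSketch({}, weights).get_representation()
    if malicious_type == "sign_flip"
    else weights
)
assert torch.equal(external["layer.weight"], -torch.ones(2, 2))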
90 changes: 45 additions & 45 deletions src/algos/fl.py
@@ -37,52 +37,52 @@ def local_test(self, **kwargs: Any) -> Tuple[float, float, float]:
         return test_loss, test_acc, time_taken
 
 
-    def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
-        """
-        Overwrite the get_model_weights method of the BaseNode
-        to add malicious attacks
-        TODO: this should be moved to BaseClient
-        """
-
-        message = {"sender": self.node_id, "round": self.round}
-
-        malicious_type = self.config.get("malicious_type", "normal")
-
-        if malicious_type == "normal":
-            message["model"] = self.model.state_dict()  # type: ignore
-        elif malicious_type == "bad_weights":
-            # Corrupt the weights
-            message["model"] = BadWeightsAttack(
-                self.config, self.model.state_dict()
-            ).get_representation()
-        elif malicious_type == "sign_flip":
-            # Flip the sign of the weights, also TODO: consider label flipping
-            message["model"] = SignFlipAttack(
-                self.config, self.model.state_dict()
-            ).get_representation()
-        elif malicious_type == "add_noise":
-            # Add noise to the weights
-            message["model"] = AddNoiseAttack(
-                self.config, self.model.state_dict()
-            ).get_representation()
-        else:
-            message["model"] = self.model.state_dict()  # type: ignore
-
-        # move the model to cpu before sending
-        for key in message["model"].keys():
-            message["model"][key] = message["model"][key].to("cpu")
-
-        # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
-        if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
-            # also stream image and labels
-            message["images"] = self.images.to("cpu")
-            message["labels"] = self.labels.to("cpu")
-
-        message["random_params"] = self.random_params
-        for key in message["random_params"].keys():
-            message["random_params"][key] = message["random_params"][key].to("cpu")
-
-        return message  # type: ignore
+    # def get_model_weights(self, **kwargs: Any) -> Dict[str, Any]:
+    #     """
+    #     Overwrite the get_model_weights method of the BaseNode
+    #     to add malicious attacks
+    #     TODO: this should be moved to BaseClient
+    #     """
+
+    #     message = {"sender": self.node_id, "round": self.round}
+
+    #     malicious_type = self.config.get("malicious_type", "normal")
+
+    #     if malicious_type == "normal":
+    #         message["model"] = self.model.state_dict()  # type: ignore
+    #     elif malicious_type == "bad_weights":
+    #         # Corrupt the weights
+    #         message["model"] = BadWeightsAttack(
+    #             self.config, self.model.state_dict()
+    #         ).get_representation()
+    #     elif malicious_type == "sign_flip":
+    #         # Flip the sign of the weights, also TODO: consider label flipping
+    #         message["model"] = SignFlipAttack(
+    #             self.config, self.model.state_dict()
+    #         ).get_representation()
+    #     elif malicious_type == "add_noise":
+    #         # Add noise to the weights
+    #         message["model"] = AddNoiseAttack(
+    #             self.config, self.model.state_dict()
+    #         ).get_representation()
+    #     else:
+    #         message["model"] = self.model.state_dict()  # type: ignore
+
+    #     # move the model to cpu before sending
+    #     for key in message["model"].keys():
+    #         message["model"][key] = message["model"][key].to("cpu")
+
+    #     # assert hasattr(self, 'images') and hasattr(self, 'labels'), "Images and labels not found"
+    #     if "gia" in self.config and hasattr(self, 'images') and hasattr(self, 'labels'):
+    #         # also stream image and labels
+    #         message["images"] = self.images.to("cpu")
+    #         message["labels"] = self.labels.to("cpu")
+
+    #     message["random_params"] = self.random_params
+    #     for key in message["random_params"].keys():
+    #         message["random_params"][key] = message["random_params"][key].to("cpu")
+
+    #     return message  # type: ignore
 
     def run_protocol(self):
         print(f"Client {self.node_id} ready to start training")
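Why commenting out the FedAvg override is enough (a standalone sketch of Python attribute lookup; Base and Client are illustrative names, not repo classes): with no definition in the subclass, calls resolve to the attack-aware BaseNode implementation.

# Standalone sketch (illustrative, not repo code).
class Base:
    def get_model_weights(self, get_external_repr: bool = True):
        return {"sender": 0, "round": 0, "model": {}}

class Client(Base):
    pass  # override removed, as in fl.py after this commit

assert Client().get_model_weights() == {"sender": 0, "round": 0, "model": {}}
assert Client.get_model_weights is Base.get_model_weights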
1 change: 1 addition & 0 deletions src/configs/malicious_config.py
@@ -60,4 +60,5 @@
     "gradient_attack": gradient_attack,
     "backdoor_attack": backdoor_attack,
     "data_poisoning": data_poisoning,
+    "label_flip": label_flip,
 }
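A hypothetical node config selecting one of these attacks. Only "malicious_type" and the values "normal", "sign_flip", "bad_weights", and "add_noise" appear in this diff; every other key below is an assumption.

# Hypothetical config sketch; "noise_std" is illustrative, not from this commit.
node_config = {
    "malicious_type": "add_noise",  # "normal", "sign_flip", "bad_weights", or "add_noise"
    "noise_std": 0.1,               # assumed attack parameter
}

The new "label_flip" entry presumably registers a data-level attack alongside these weight-level ones, handled at training time rather than in get_malicious_model_weights.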
