diff --git a/examples/cifar100/.gitignore b/examples/cifar100/.gitignore
new file mode 100644
index 000000000..3fee66770
--- /dev/null
+++ b/examples/cifar100/.gitignore
@@ -0,0 +1,3 @@
+data/*
+results/*
+*.npz
diff --git a/examples/cifar100/config.py b/examples/cifar100/config.py
new file mode 100644
index 000000000..f2d5132df
--- /dev/null
+++ b/examples/cifar100/config.py
@@ -0,0 +1,13 @@
+settings = {
+    "N_CLIENTS": 5,
+    "DISCOVER_HOST": "localhost",
+    "DISCOVER_PORT": 8092,
+    "SECURE": False,
+    "VERIFY": False,
+    "ADMIN_TOKEN": None,
+    "CLIENT_TOKEN": None,
+    "BATCH_SIZE": 128,
+    "EPOCHS": 1,
+    "BALANCED": True,
+    "IID": True,
+}
diff --git a/examples/cifar100/data.py b/examples/cifar100/data.py
new file mode 100644
index 000000000..450e9e5a5
--- /dev/null
+++ b/examples/cifar100/data.py
@@ -0,0 +1,312 @@
+import os
+import pickle
+from typing import List
+
+import numpy as np
+from scipy.stats import dirichlet
+from torch.utils.data import DataLoader, Dataset, Subset
+from torchvision import datasets, transforms
+
+# Set a fixed random seed for reproducibility
+RANDOM_SEED = 42
+np.random.seed(RANDOM_SEED)
+
+
+def fine_to_coarse_labels(fine_labels: np.ndarray) -> np.ndarray:
+    # Mapping from CIFAR-100 fine-label index to coarse (superclass) index
+    coarse = np.array(
+        [
+            4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
+            3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
+            6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
+            0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
+            5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
+            16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
+            10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
+            2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
+            16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
+            18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
+        ]
+    )
+    return coarse[fine_labels]
+
+
+class CIFAR100Federated:
+    def __init__(self, root_dir: str = "./data/splits"):
+        """Initialize the splitter
+        :param root_dir: Directory to save the split datasets
+        """
+        self.root_dir = root_dir
+        self.splits = {}
+        os.makedirs(root_dir, exist_ok=True)
+
+        # Load the full dataset
+        self.transform_train = transforms.Compose(
+            [
+                transforms.RandomCrop(24),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+            ]
+        )
+        self.trainset = datasets.CIFAR100(root="./data", train=True, download=True, transform=self.transform_train)
+
+        self.transform_test = transforms.Compose(
+            [transforms.CenterCrop(24), transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))]
+        )
+        self.testset = datasets.CIFAR100(root="./data", train=False, download=True, transform=self.transform_test)
+
+    def create_splits(self, num_splits: int, balanced: bool, iid: bool) -> None:
+        """Create dataset splits based on specified parameters
+        :param num_splits: Number of splits to create
+        :param balanced: Whether splits should have equal size
+        :param iid: Whether splits should be IID
+        """
+        config_key = f"splits_{num_splits}_bal_{balanced}_iid_{iid}"
+
+        if iid:
+            indices = self._create_iid_splits(num_splits, balanced)
+        else:
+            indices = self._create_non_iid_splits(num_splits, balanced)
+
+        # Save splits
+        for i, split_indices in enumerate(indices):
+            split_path = os.path.join(self.root_dir, f"{config_key}_split_{i}.pkl")
+            with open(split_path, "wb") as f:
+                pickle.dump(split_indices, f)
+
+        self.splits[config_key] = indices
+
+    def _create_iid_splits(self, num_splits: int, balanced: bool) -> List[np.ndarray]:
+        """Create IID splits of the dataset"""
+        indices = np.arange(len(self.trainset))
+        np.random.shuffle(indices)
+
+        if balanced:
+            # Equal size splits
+            split_size = len(indices) // num_splits
+            return [indices[i * split_size : (i + 1) * split_size] for i in range(num_splits)]
+        else:
+            # Random size splits
+            split_points = sorted(np.random.choice(len(indices) - 1, num_splits - 1, replace=False))
+            return np.split(indices, split_points)
+
+    def _create_non_iid_splits(self, num_splits: int, balanced: bool) -> List[np.ndarray]:
+        """Create non-IID splits using Pachinko Allocation Method (PAM)"""
+        # Initialize parameters
+        alpha = 0.1  # Root Dirichlet parameter
+        beta = 10.0  # Coarse-to-fine Dirichlet parameter
+        total_examples = len(self.trainset)
+
+        # Calculate examples per split
+        if balanced:
+            examples_per_split = [total_examples // num_splits] * num_splits
+        else:
+            # Use Dirichlet to create unbalanced split sizes
+            split_ratios = np.random.dirichlet([0.5] * num_splits)  # Lower alpha = more unbalanced
+            examples_per_split = np.round(split_ratios * total_examples).astype(int)
+            # Ensure we use exactly total_examples
+            examples_per_split[-1] = total_examples - examples_per_split[:-1].sum()
+
+        # Get fine labels and map them to coarse labels
+        fine_labels = np.array(self.trainset.targets)
+        coarse_labels = fine_to_coarse_labels(fine_labels)
+
+        # Initialize DAG structure (track available labels)
+        available_coarse = list(range(20))  # 20 coarse labels as list instead of set
+        available_fine = {c: set(np.where(coarse_labels == c)[0]) for c in available_coarse}
+
+        indices_per_split = []
+        for split_idx in range(num_splits):
+            split_indices = []
+            N = examples_per_split[split_idx]  # Use the pre-calculated split size
+
+            # Sample root distribution over coarse labels
+            coarse_probs = dirichlet.rvs(alpha=[alpha] * len(available_coarse), size=1, random_state=RANDOM_SEED + split_idx)[0]
+
+            # Sample fine label distributions for each available coarse label
+            fine_distributions = {}
+            for c in available_coarse:
+                if len(available_fine[c]) > 0:
+                    fine_probs = dirichlet.rvs(alpha=[beta] * len(available_fine[c]), size=1, random_state=RANDOM_SEED + split_idx + c)[0]
+                    fine_distributions[c] = fine_probs
+
+            # Sample N examples for this split
+            for _ in range(N):
+                if len(available_coarse) == 0:
+                    break
+
+                # Sample coarse label
+                coarse_idx = np.random.choice(available_coarse, p=coarse_probs)
+
+                if len(available_fine[coarse_idx]) == 0:
+                    # Remove empty coarse label and renormalize
+                    idx_to_remove = available_coarse.index(coarse_idx)
+                    available_coarse.remove(coarse_idx)
+                    coarse_probs = self._renormalize(coarse_probs, idx_to_remove)
+                    continue
+
+                # Sample fine label
+                fine_probs = fine_distributions[coarse_idx]
+                available_fine_indices = list(available_fine[coarse_idx])
+                fine_probs = fine_probs[: len(available_fine_indices)]
+                fine_probs = fine_probs / fine_probs.sum()  # Renormalize
+                fine_idx = np.random.choice(available_fine_indices, p=fine_probs)
+
+                # Add example to split
+                split_indices.append(fine_idx)
+
+                # Remove selected example
+                available_fine[coarse_idx].remove(fine_idx)
+
+                # Renormalize if necessary
+                if len(available_fine[coarse_idx]) == 0:
+                    idx_to_remove = available_coarse.index(coarse_idx)
+                    available_coarse.remove(coarse_idx)
+                    coarse_probs = self._renormalize(coarse_probs, idx_to_remove)
+
+            indices_per_split.append(np.array(split_indices))
+
+        return indices_per_split
+
+    def _renormalize(self, probs: np.ndarray, removed_idx: int) -> np.ndarray:
+        """Implementation of Algorithm 8 from the paper"""
+        # Create a list of valid indices (excluding the removed index)
+        valid_indices = [i for i in range(len(probs)) if i != removed_idx]
+
+        # Select only the probabilities for valid indices
+        valid_probs = probs[valid_indices]
+
+        # Normalize the remaining probabilities
+        return valid_probs / valid_probs.sum()
+
+    def get_split(self, split_id: int, num_splits: int, balanced: bool, iid: bool) -> Dataset:
+        """Get a specific split of the dataset
+        :param split_id: ID of the split to retrieve
+        :param num_splits: Total number of splits
+        :param balanced: Whether splits are balanced
+        :param iid: Whether splits are IID
+        :return: Dataset split
+        """
+        config_key = f"splits_{num_splits}_bal_{balanced}_iid_{iid}"
+        split_path = os.path.join(self.root_dir, f"{config_key}_split_{split_id}.pkl")
+
+        if not os.path.exists(split_path):
+            self.create_splits(num_splits, balanced, iid)
+
+        with open(split_path, "rb") as f:
+            indices = pickle.load(f)  # noqa: S301
+
+        return Subset(self.trainset, indices)
+
+
+def get_data_loader(num_splits: int = 5, balanced: bool = True, iid: bool = True, batch_size: int = 100, is_train: bool = True):
+    """Get a data loader for the CIFAR-100 dataset
+    :param num_splits: Number of splits to create
+    :param balanced: Whether splits are balanced
+    :param iid: Whether splits are IID
+    :param batch_size: Batch size
+    :param is_train: Whether to get the training or test data loader
+    :return: Data loader
+    """
+    cifar_data = CIFAR100Federated()
+
+    if is_train:
+        split_id = os.environ.get("FEDN_DATA_SPLIT_ID", 0)
+        dataset = cifar_data.get_split(split_id=split_id, num_splits=num_splits, balanced=balanced, iid=iid)
+        print(f"Getting data loader for split {split_id} of trainset (size: {len(dataset)})")
+    else:
+        dataset = cifar_data.testset
+        print(f"Getting data loader for testset (size: {len(dataset)})")
+
+    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
diff --git a/examples/cifar100/init_fedn.py b/examples/cifar100/init_fedn.py
new file mode 100644
index 000000000..e4547b5bc
--- /dev/null
+++ b/examples/cifar100/init_fedn.py
@@ -0,0 +1,13 @@
+from config import settings
+from fedn import APIClient
+
+client = APIClient(
+    host=settings["DISCOVER_HOST"],
+    port=settings["DISCOVER_PORT"],
+    secure=settings["SECURE"],
+    verify=settings["VERIFY"],
+    token=settings["ADMIN_TOKEN"],
+)
+
+result = client.set_active_model("seed.npz")
+print(result["message"])
diff --git a/examples/cifar100/init_seed.py b/examples/cifar100/init_seed.py
new file mode 100644
index 000000000..4f5fe6b5e
--- /dev/null
+++ b/examples/cifar100/init_seed.py
@@ -0,0 +1,60 @@
+import collections
+
+import torch
+from torch import nn
+from torchvision import models
+
+from fedn.utils.helpers.helpers import get_helper
+
+HELPER_MODULE = "numpyhelper"
+helper = get_helper(HELPER_MODULE)
+
+
+# Function to replace BatchNorm layers with GroupNorm
+def replace_bn_with_gn(module, num_groups=32):
+    for name, child in module.named_children():
+        if isinstance(child, nn.BatchNorm2d):
+            num_channels = child.num_features
+            setattr(module, name, nn.GroupNorm(num_groups=num_groups, num_channels=num_channels))
+        else:
+            replace_bn_with_gn(child, num_groups)  # Apply recursively to nested modules
+
+
+def compile_model():
+    # Load ResNet-18 and replace BatchNorm with GroupNorm
+    resnet18 = models.resnet18(weights=None)
+    replace_bn_with_gn(resnet18)
+    # Modify final layer for CIFAR-100 (100 classes)
+    resnet18.fc = nn.Linear(512, 100)
+    return resnet18
+
+
+def save_parameters(model, out_path):
+    parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
+    helper.save(parameters_np, out_path)
+
+
+def init_seed(out_path="seed.npz"):
+    model = compile_model()
+    save_parameters(model, out_path)
+
+
+def load_parameters(model_path):
+    """Load model parameters from file and populate model.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The loaded model.
+    :rtype: torch.nn.Module
+    """
+    model = compile_model()
+    parameters_np = helper.load(model_path)
+
+    params_dict = zip(model.state_dict().keys(), parameters_np)
+    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
+    model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+if __name__ == "__main__":
+    init_seed("seed.npz")
diff --git a/examples/cifar100/requirements.txt b/examples/cifar100/requirements.txt
new file mode 100644
index 000000000..c1236222d
--- /dev/null
+++ b/examples/cifar100/requirements.txt
@@ -0,0 +1,3 @@
+torch
+torchvision
+scipy
\ No newline at end of file
diff --git a/examples/cifar100/run_client.py b/examples/cifar100/run_client.py
new file mode 100644
index 000000000..caf6c6cb6
--- /dev/null
+++ b/examples/cifar100/run_client.py
@@ -0,0 +1,203 @@
+import argparse
+import io
+import os
+import uuid
+
+import torch
+from data import get_data_loader
+from init_seed import load_parameters, save_parameters
+from torch import nn, optim
+
+from config import settings
+from fedn import FednClient
+from fedn.network.clients.fedn_client import ConnectToApiResult
+from fedn.utils.helpers.helpers import get_helper
+
+helper = get_helper("numpyhelper")
+
+
+def get_api_url(api_url: str, api_port: int, secure: bool = False):
+    if secure:
+        url = f"https://{api_url}:{api_port}" if api_port else f"https://{api_url}"
+    else:
+        url = f"http://{api_url}:{api_port}" if api_port else f"http://{api_url}"
+    if not url.endswith("/"):
+        url += "/"
+    return url
+
+
+def on_train(in_model, client_settings):
+    # Save model to temp file
+    inpath = helper.get_tmp_path()
+    with open(inpath, "wb") as fh:
+        fh.write(in_model.getbuffer())
+
+    # Load model from temp file
+    resnet18 = load_parameters(inpath)
+    os.unlink(inpath)
+
+    # Move model to GPU if available
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    resnet18 = resnet18.to(device)
+
+    # Define loss function and optimizer
+    criterion = nn.CrossEntropyLoss()
+    learning_rate = 0.001
+    weight_decay = 5e-4
+    optimizer = optim.Adam(resnet18.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
+
+    # Get data loader for trainset
+    trainloader = get_data_loader(
+        num_splits=settings["N_CLIENTS"],
+        balanced=settings["BALANCED"],
+        iid=settings["IID"],
+        is_train=True,
+        batch_size=settings["BATCH_SIZE"],
+    )
+
+    # Calculate number of batches
+    num_batches = len(trainloader)
+
+    # Training loop
+    num_epochs = settings["EPOCHS"]
+    for epoch in range(num_epochs):
+        resnet18.train()
+
+        for batch_idx, (inputs, labels) in enumerate(trainloader):
+            inputs, labels = inputs.to(device), labels.to(device)
+
+            optimizer.zero_grad()
+            outputs = resnet18(inputs)
+            loss = criterion(outputs, labels)
+            loss.backward()
+            optimizer.step()
+
+            if batch_idx % 10 == 0:
+                print(f"Epoch: {epoch}, Batch: {batch_idx}/{num_batches}, Loss: {loss.item():.4f}")
+
+        scheduler.step()
+
+    # Save model parameters
+    outpath = helper.get_tmp_path()
+    save_parameters(resnet18, outpath)
+    with open(outpath, "rb") as fr:
+        out_model = io.BytesIO(fr.read())
+    os.unlink(outpath)
+
+    # Return model and metadata
+    training_metadata = {
+        "num_examples": len(trainloader.dataset),
+        "batch_size": settings["BATCH_SIZE"],
+        "epochs": num_epochs,
+        "lr": learning_rate,
+    }
+    metadata = {"training_metadata": training_metadata}
+    return out_model, metadata
+
+
+def on_validate(in_model):
+    # Save model to temp file
+    inpath = helper.get_tmp_path()
+    with open(inpath, "wb") as fh:
+        fh.write(in_model.getbuffer())
+
+    # Load model from temp file
+    resnet18 = load_parameters(inpath)
+    os.unlink(inpath)
+
+    # Move model to GPU if available
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    resnet18 = resnet18.to(device)
+    resnet18.eval()
+
+    criterion = nn.CrossEntropyLoss()
+
+    # Calculate training metrics
+    trainloader = get_data_loader(
+        num_splits=settings["N_CLIENTS"],
+        balanced=settings["BALANCED"],
+        iid=settings["IID"],
+        is_train=True,
+        batch_size=settings["BATCH_SIZE"],
+    )
+    train_loss = 0
+    train_correct = 0
+    train_total = 0
+
+    with torch.no_grad():
+        for inputs, labels in trainloader:
+            inputs, labels = inputs.to(device), labels.to(device)
+            outputs = resnet18(inputs)
+            loss = criterion(outputs, labels)
+
+            train_loss += loss.item()
+            _, predicted = outputs.max(1)
+            train_total += labels.size(0)
+            train_correct += predicted.eq(labels).sum().item()
+
+    train_accuracy = train_correct / train_total
+    train_loss = train_loss / len(trainloader)
+
+    # Calculate test metrics
+    testloader = get_data_loader(
+        is_train=False,
+        batch_size=settings["BATCH_SIZE"],
+    )
+    test_loss = 0
+    test_correct = 0
+    test_total = 0
+
+    with torch.no_grad():
+        for inputs, labels in testloader:
+            inputs, labels = inputs.to(device), labels.to(device)
+            outputs = resnet18(inputs)
+            loss = criterion(outputs, labels)
+
+            test_loss += loss.item()
+            _, predicted = outputs.max(1)
+            test_total += labels.size(0)
+            test_correct += predicted.eq(labels).sum().item()
+
+    test_accuracy = test_correct / test_total
+    test_loss = test_loss / len(testloader)
+
+    metrics = {
+        "test_accuracy": test_accuracy,
+        "test_loss": test_loss,
+        "train_accuracy": train_accuracy,
+        "train_loss": train_loss,
+    }
+    return metrics
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="CIFAR100 Client")
+    parser.add_argument("--split-id", type=int, required=True, help="The split ID")
+    args = parser.parse_args()
+
+    # Expose the selected split to the data loader, which reads FEDN_DATA_SPLIT_ID in data.py
+    os.environ["FEDN_DATA_SPLIT_ID"] = str(args.split_id)
+
+    client = FednClient(train_callback=on_train, validate_callback=on_validate)
+    url = get_api_url(settings["DISCOVER_HOST"], settings["DISCOVER_PORT"], settings["SECURE"])
+    client.set_name(f"cifar100-client-{args.split_id}")
+    client.set_client_id(str(uuid.uuid4()))
+
+    controller_config = {
+        "name": client.name,
+        "client_id": client.client_id,
+        "package": "local",
+        "preferred_combiner": "",
+    }
+
+    result, combiner_config = client.connect_to_api(url=url, token=settings["CLIENT_TOKEN"], json=controller_config)
+
+    if result != ConnectToApiResult.Assigned:
+        print("Failed to connect to API, exiting.")
+        exit(1)
+
+    result = client.init_grpchandler(config=combiner_config, client_name=client.client_id, token=settings["CLIENT_TOKEN"])
+
+    if not result:
+        print("Failed to initialize gRPC handler, exiting.")
+        exit(1)
+
+    client.run()
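The snippet below is a minimal sketch, not part of the patch, of how the data partitioner added in data.py could be inspected locally before starting any clients. It assumes it is run from examples/cifar100/ so that data.py is importable, that torchvision can download CIFAR-100 into ./data, and that the variable names (split_id, top, etc.) are purely illustrative; building the non-IID splits the first time can take a while, since examples are drawn one at a time.

# Illustrative only -- not part of the diff above.
# Assumes data.py from this patch is on the import path and CIFAR-100
# can be downloaded into ./data by torchvision.
import numpy as np

from data import CIFAR100Federated, fine_to_coarse_labels

cifar = CIFAR100Federated()

# Build 5 balanced, non-IID splits (cached as .pkl files under ./data/splits)
# and print a rough summary of each split's coarse-label skew.
for split_id in range(5):
    subset = cifar.get_split(split_id=split_id, num_splits=5, balanced=True, iid=False)
    fine = np.array(cifar.trainset.targets)[subset.indices]
    coarse = fine_to_coarse_labels(fine)
    top = np.bincount(coarse, minlength=20).argsort()[::-1][:3]
    print(f"split {split_id}: {len(subset)} examples, dominant coarse labels: {top.tolist()}")

A typical end-to-end run of the example, as suggested by the files in this patch, would be: generate the seed model with python init_seed.py, upload it with python init_fedn.py, and then start one client per split with python run_client.py --split-id <n>.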