Commit 755051d

Merge branch 'feature/SK-1183' into feature/SK-1257

Andreas Hellander committed Jan 10, 2025
2 parents aa11dc7 + f8f9732
Showing 7 changed files with 607 additions and 0 deletions.
3 changes: 3 additions & 0 deletions examples/cifar100/.gitignore
@@ -0,0 +1,3 @@
data/*
results/*
*.npz
13 changes: 13 additions & 0 deletions examples/cifar100/config.py
@@ -0,0 +1,13 @@
settings = {
"N_CLIENTS": 5,
"DISCOVER_HOST": "localhost",
"DISCOVER_PORT": 8092,
"SECURE": False,
"VERIFY": False,
"ADMIN_TOKEN": None,
"CLIENT_TOKEN": None,
"BATCH_SIZE": 128,
"EPOCHS": 1,
"BALANCED": True,
"IID": True,
}
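For context, a minimal sketch of how these settings are consumed by the example's helper code (get_data_loader is defined in data.py below; this wiring is illustrative, not part of the commit):

from config import settings
from data import get_data_loader

loader = get_data_loader(
    num_splits=settings["N_CLIENTS"],
    balanced=settings["BALANCED"],
    iid=settings["IID"],
    batch_size=settings["BATCH_SIZE"],
)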
312 changes: 312 additions & 0 deletions examples/cifar100/data.py
@@ -0,0 +1,312 @@
import os
import pickle
from typing import List

import numpy as np
from scipy.stats import dirichlet
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms

# Set a fixed random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)


def fine_to_coarse_labels(fine_labels: np.ndarray) -> np.ndarray:
    """Map CIFAR-100 fine labels (0-99) to their 20 coarse superclass labels."""
    coarse = np.array(
        [
            4, 1, 14, 8, 0, 6, 7, 7, 18, 3,
            3, 14, 9, 18, 7, 11, 3, 9, 7, 11,
            6, 11, 5, 10, 7, 6, 13, 15, 3, 15,
            0, 11, 1, 10, 12, 14, 16, 9, 11, 5,
            5, 19, 8, 8, 15, 13, 14, 17, 18, 10,
            16, 4, 17, 4, 2, 0, 17, 4, 18, 17,
            10, 3, 2, 12, 12, 16, 12, 1, 9, 19,
            2, 10, 0, 1, 16, 12, 9, 13, 15, 13,
            16, 19, 2, 4, 6, 19, 5, 5, 8, 19,
            18, 1, 2, 15, 6, 0, 17, 8, 14, 13,
        ]
    )
return coarse[fine_labels]
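# Example: fine labels 0, 1, and 99 belong to coarse classes 4, 1, and 13 in
# the table above, so:
#   fine_to_coarse_labels(np.array([0, 1, 99]))  # -> array([ 4,  1, 13])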


class CIFAR100Federated:
    """Create, cache, and serve federated splits of CIFAR-100."""

    def __init__(self, root_dir: str = "./data/splits"):
        """Initialize the splitter.

        :param root_dir: Directory to save the split datasets
        """
self.root_dir = root_dir
self.splits = {}
os.makedirs(root_dir, exist_ok=True)

# Load the full dataset
self.transform_train = transforms.Compose(
[
transforms.RandomCrop(24),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
]
)
self.trainset = datasets.CIFAR100(root="./data", train=True, download=True, transform=self.transform_train)

self.transform_test = transforms.Compose(
[transforms.CenterCrop(24), transforms.ToTensor(), transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))]
)
self.testset = datasets.CIFAR100(root="./data", train=False, download=True, transform=self.transform_test)

def create_splits(self, num_splits: int, balanced: bool, iid: bool) -> None:
"""Create dataset splits based on specified parameters
:param num_splits: Number of splits to create
:param balanced: Whether splits should have equal size
:param iid: Whether splits should be IID
"""
config_key = f"splits_{num_splits}_bal_{balanced}_iid_{iid}"

if iid:
indices = self._create_iid_splits(num_splits, balanced)
else:
indices = self._create_non_iid_splits(num_splits, balanced)

# Save splits
for i, split_indices in enumerate(indices):
split_path = os.path.join(self.root_dir, f"{config_key}_split_{i}.pkl")
with open(split_path, "wb") as f:
pickle.dump(split_indices, f)

self.splits[config_key] = indices

def _create_iid_splits(self, num_splits: int, balanced: bool) -> List[np.ndarray]:
"""Create IID splits of the dataset"""
indices = np.arange(len(self.trainset))
np.random.shuffle(indices)

if balanced:
# Equal size splits
split_size = len(indices) // num_splits
return [indices[i * split_size : (i + 1) * split_size] for i in range(num_splits)]
else:
# Random size splits
split_points = sorted(np.random.choice(len(indices) - 1, num_splits - 1, replace=False))
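            # Example: with 10 shuffled indices and split_points [3, 7],
            # np.split returns three arrays of sizes 3, 4, and 3.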
return np.split(indices, split_points)

def _create_non_iid_splits(self, num_splits: int, balanced: bool) -> List[np.ndarray]:
"""Create non-IID splits using Pachinko Allocation Method (PAM)"""
# Initialize parameters
alpha = 0.1 # Root Dirichlet parameter
beta = 10.0 # Coarse-to-fine Dirichlet parameter
total_examples = len(self.trainset)

# Calculate examples per split
if balanced:
examples_per_split = [total_examples // num_splits] * num_splits
else:
# Use Dirichlet to create unbalanced split sizes
split_ratios = np.random.dirichlet([0.5] * num_splits) # Lower alpha = more unbalanced
examples_per_split = np.round(split_ratios * total_examples).astype(int)
# Ensure we use exactly total_examples
examples_per_split[-1] = total_examples - examples_per_split[:-1].sum()

# Get fine labels and map them to coarse labels
fine_labels = np.array(self.trainset.targets)
coarse_labels = fine_to_coarse_labels(fine_labels)

# Initialize DAG structure (track available labels)
        available_coarse = list(range(20))  # 20 coarse labels (a list, so np.random.choice can sample from it)
        available_fine = {c: set(np.where(coarse_labels == c)[0]) for c in available_coarse}  # remaining example indices per coarse label

indices_per_split = []
for split_idx in range(num_splits):
split_indices = []
N = examples_per_split[split_idx] # Use the pre-calculated split size

# Sample root distribution over coarse labels
coarse_probs = dirichlet.rvs(alpha=[alpha] * len(available_coarse), size=1, random_state=RANDOM_SEED + split_idx)[0]

# Sample fine label distributions for each available coarse label
fine_distributions = {}
for c in available_coarse:
if len(available_fine[c]) > 0:
fine_probs = dirichlet.rvs(alpha=[beta] * len(available_fine[c]), size=1, random_state=RANDOM_SEED + split_idx + c)[0]
fine_distributions[c] = fine_probs

# Sample N examples for this split
for _ in range(N):
if len(available_coarse) == 0:
break

# Sample coarse label
coarse_idx = np.random.choice(available_coarse, p=coarse_probs)

if len(available_fine[coarse_idx]) == 0:
# Remove empty coarse label and renormalize
idx_to_remove = available_coarse.index(coarse_idx)
available_coarse.remove(coarse_idx)
coarse_probs = self._renormalize(coarse_probs, idx_to_remove)
continue

# Sample fine label
fine_probs = fine_distributions[coarse_idx]
available_fine_indices = list(available_fine[coarse_idx])
fine_probs = fine_probs[: len(available_fine_indices)]
fine_probs = fine_probs / fine_probs.sum() # Renormalize
fine_idx = np.random.choice(available_fine_indices, p=fine_probs)

# Add example to split
split_indices.append(fine_idx)

# Remove selected example
available_fine[coarse_idx].remove(fine_idx)

# Renormalize if necessary
if len(available_fine[coarse_idx]) == 0:
idx_to_remove = available_coarse.index(coarse_idx)
available_coarse.remove(coarse_idx)
coarse_probs = self._renormalize(coarse_probs, idx_to_remove)

indices_per_split.append(np.array(split_indices))

return indices_per_split

def _renormalize(self, probs: np.ndarray, removed_idx: int) -> np.ndarray:
"""Implementation of Algorithm 8 from the paper"""
# Create a list of valid indices (excluding the removed index)
valid_indices = [i for i in range(len(probs)) if i != removed_idx]

# Select only the probabilities for valid indices
valid_probs = probs[valid_indices]

# Normalize the remaining probabilities
return valid_probs / valid_probs.sum()

def get_split(self, split_id: int, num_splits: int, balanced: bool, iid: bool) -> Dataset:
"""Get a specific split of the dataset
:param split_id: ID of the split to retrieve
:param num_splits: Total number of splits
:param balanced: Whether splits are balanced
:param iid: Whether splits are IID
:return: Dataset split
"""
config_key = f"splits_{num_splits}_bal_{balanced}_iid_{iid}"
split_path = os.path.join(self.root_dir, f"{config_key}_split_{split_id}.pkl")

if not os.path.exists(split_path):
self.create_splits(num_splits, balanced, iid)

with open(split_path, "rb") as f:
indices = pickle.load(f) # noqa: S301

return Subset(self.trainset, indices)


def get_data_loader(num_splits: int = 5, balanced: bool = True, iid: bool = True, batch_size: int = 100, is_train: bool = True):
"""Get a data loader for the CIFAR-100 dataset
:param num_splits: Number of splits to create
:param balanced: Whether splits are balanced
:param iid: Whether splits are IID
:param batch_size: Batch size
:param is_train: Whether to get the training or test data loader
:return: Data loader
"""
cifar_data = CIFAR100Federated()

if is_train:
        split_id = int(os.environ.get("FEDN_DATA_SPLIT_ID", "0"))
dataset = cifar_data.get_split(split_id=split_id, num_splits=num_splits, balanced=balanced, iid=iid)
print(f"Getting data loader for split {split_id} of trainset (size: {len(dataset)})")
else:
dataset = cifar_data.testset
print(f"Getting data loader for testset (size: {len(dataset)})")

return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
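A minimal usage sketch for the loader above, assuming the dataset has been downloaded and FEDN_DATA_SPLIT_ID is set in the client environment (the batch shape reflects the 24x24 crops applied by the train transform):

from data import get_data_loader

train_loader = get_data_loader(num_splits=5, balanced=True, iid=True, batch_size=128)
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([128, 3, 24, 24])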
13 changes: 13 additions & 0 deletions examples/cifar100/init_fedn.py
@@ -0,0 +1,13 @@
from config import settings
from fedn import APIClient

client = APIClient(
host=settings["DISCOVER_HOST"],
port=settings["DISCOVER_PORT"],
secure=settings["SECURE"],
verify=settings["VERIFY"],
token=settings["ADMIN_TOKEN"],
)

result = client.set_active_model("seed.npz")
print(result["message"])