Commit

code checks
mattiasakesson committed Dec 9, 2024
1 parent 4e50210 commit a502f4c
Showing 1 changed file with 16 additions and 26 deletions.
examples/mnist-pytorch-DPSGD/client/train.py (42 changes: 16 additions & 26 deletions)
@@ -1,21 +1,20 @@
-import math
 import os
 import sys
-import yaml
 
+import numpy as np
 import torch
+import yaml
+from data import load_data
+from model import load_parameters, save_parameters
 from opacus import PrivacyEngine
 from opacus.utils.batch_memory_manager import BatchMemoryManager
 
-
-from model import load_parameters, save_parameters
-from data import load_data
 from fedn.utils.helpers.helpers import save_metadata
-import numpy as np
 
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.append(os.path.abspath(dir_path))
 
 
 # Define a custom Dataset class
 class CustomDataset(torch.utils.data.Dataset):
     def __init__(self, x_data, y_data):
@@ -30,14 +29,16 @@ def __getitem__(self, idx):
         y_data = self.y_data[idx]
         return x_data, y_data
 
+
 MAX_PHYSICAL_BATCH_SIZE = 32
 EPOCHS = 1
-EPSILON = 1000.
+EPSILON = 1000.0
 DELTA = 1e-5
 MAX_GRAD_NORM = 1.2
 GLOBAL_ROUNDS = 10
 HARDLIMIT = True
 
+
 def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01):
     """Complete a model update.
@@ -58,7 +59,6 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01)
     :param lr: The learning rate to use.
     :type lr: float
     """
-
     with open("../../client_settings.yaml", "r") as fh:
         try:
             settings = yaml.safe_load(fh)
@@ -80,9 +80,6 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01)
 
     # Train
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-    n_batches = int(math.ceil(len(x_train) / batch_size))
-    criterion = torch.nn.NLLLoss()
-
     privacy_engine = PrivacyEngine()
 
     if os.path.isfile("privacy_accountant.state"):
@@ -91,17 +88,15 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01)
     trainset = CustomDataset(x_train, y_train)
     train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=2)
 
-
     try:
         epsilon_spent = privacy_engine.get_epsilon(DELTA)
-    except:
+    except ValueError:
        epsilon_spent = 0
     print("epsilon before training: ", epsilon_spent)
 
-    round_epsilon = np.sqrt((epsilon_spent/EPSILON*np.sqrt(GLOBAL_ROUNDS))**2+1)*EPSILON/np.sqrt(GLOBAL_ROUNDS)
+    round_epsilon = np.sqrt((epsilon_spent / EPSILON * np.sqrt(GLOBAL_ROUNDS)) ** 2 + 1) * EPSILON / np.sqrt(GLOBAL_ROUNDS)
 
     print("target epsilon: ", round_epsilon)
-
     model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
         module=model,
         optimizer=optimizer,
@@ -116,10 +111,10 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01)
     train_dp(model, train_loader, optimizer, EPOCHS, device, privacy_engine)
     try:
         print("epsilon after training: ", privacy_engine.get_epsilon(DELTA))
-    except:
+    except ValueError:
         print("cant calculate epsilon")
 
-    if HARDLIMIT and privacy_engine.get_epsilon(DELTA)<EPSILON:
+    if HARDLIMIT and privacy_engine.get_epsilon(DELTA) < EPSILON:
         # Metadata needed for aggregation server side
         metadata = {
             # num_examples are mandatory
@@ -138,17 +133,13 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01)
         print("Epsilon too high, not saving model")
 
     # Save privacy accountant
-    torch.save(privacy_engine.accountant,"privacy_accountant.state")
+    torch.save(privacy_engine.accountant, "privacy_accountant.state")
 
 
 def train_dp(model, train_loader, optimizer, epoch, device, privacy_engine):
     model.train()
     criterion = torch.nn.NLLLoss() # nn.CrossEntropyLoss()
-    with BatchMemoryManager(
-        data_loader=train_loader,
-        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
-        optimizer=optimizer
-    ) as memory_safe_data_loader:
-
+    with BatchMemoryManager(data_loader=train_loader, max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, optimizer=optimizer) as memory_safe_data_loader:
         for i, (images, target) in enumerate(memory_safe_data_loader):
             optimizer.zero_grad()
             images = images.to(device)
@@ -161,6 +152,5 @@ def train_dp(model, train_loader, optimizer, epoch, device, privacy_engine):
             optimizer.step()
 
 
-
 if __name__ == "__main__":
     train(sys.argv[1], sys.argv[2])
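
Editorial note on the budgeting in train() (a sketch, not part of the commit): with e_round = EPSILON / sqrt(GLOBAL_ROUNDS), the round_epsilon expression reduces algebraically to sqrt(epsilon_spent**2 + e_round**2), i.e. each round's target composes the epsilon already spent with one round's share of the total budget in quadrature. A client that lands exactly on its target every round therefore ends at EPSILON after GLOBAL_ROUNDS rounds. A minimal check, using only constants from the diff:

import numpy as np

EPSILON = 1000.0
GLOBAL_ROUNDS = 10

e_round = EPSILON / np.sqrt(GLOBAL_ROUNDS)  # one round's share of the budget

epsilon_spent = 0.0
for _ in range(GLOBAL_ROUNDS):
    # Same expression as in train(); equals sqrt(epsilon_spent^2 + e_round^2).
    round_epsilon = np.sqrt((epsilon_spent / EPSILON * np.sqrt(GLOBAL_ROUNDS)) ** 2 + 1) * EPSILON / np.sqrt(GLOBAL_ROUNDS)
    assert np.isclose(round_epsilon, np.sqrt(epsilon_spent**2 + e_round**2))
    epsilon_spent = round_epsilon  # optimistic: accountant lands exactly on target

print(epsilon_spent)  # -> 1000.0 after GLOBAL_ROUNDS rounds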

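The script also restores the accountant when "privacy_accountant.state" exists; the load line sits in an unexpanded part of the diff. A minimal sketch of the save/load round trip, assuming the file name from the diff and Opacus' default accountant (the repo's actual load line may differ, and recent torch versions may require weights_only=False when unpickling):

import os

import torch
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()

# Restore accounting state from a previous federated round, if present
# (the load side here is an assumption, not shown in the hunks above).
if os.path.isfile("privacy_accountant.state"):
    privacy_engine.accountant = torch.load("privacy_accountant.state")

# ... make_private_with_epsilon(...) and training happen here ...

# Persist the accountant for the next round, as train() does.
torch.save(privacy_engine.accountant, "privacy_accountant.state")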