-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathfedprox.py
30 lines (28 loc) · 1.17 KB
/
fedprox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from flgo.algorithm.fedbase import BasicServer, BasicClient
import copy
import torch
from flgo.utils import fmodule
class Server(BasicServer):
    """FedProx server.

    Only ``initialize`` is overridden here; every other server behaviour
    (sampling, aggregation, ...) is inherited from ``BasicServer``.
    """

    def initialize(self, *args, **kwargs):
        """Register FedProx's algorithm hyperparameters and their defaults.

        ``mu`` (default 0.1) is the proximal coefficient. ``Client.train`` in
        this module reads it as ``self.mu``; this presumes ``init_algo_para``
        makes the value visible on clients — confirm in ``BasicServer``.
        """
        default_hyperparams = dict(mu=0.1)
        self.init_algo_para(default_hyperparams)
class Client(BasicClient):
    """FedProx client.

    Local training minimizes the task loss plus a proximal penalty on the
    distance between the trained weights and a snapshot of the weights the
    model had when ``train`` was called.
    """

    @fmodule.with_multi_gpus
    def train(self, model):
        """Run ``self.num_steps`` local optimization steps on ``model`` in place.

        Each step minimizes::

            task_loss + 0.5 * mu * sum_i ||w_i - w_src_i||^2

        where ``w_src`` is a deep copy of ``model`` taken before training.

        Args:
            model: model to train; updated in place. Must support
                ``freeze_grad`` on its deep copy (presumably an flgo
                ``FModule`` — confirm against callers).

        Returns:
            None.
        """
        # Snapshot of the incoming (global) parameters. ``freeze_grad`` is
        # expected to disable gradients on the copy so only ``model`` is
        # trained — NOTE(review): confirm in flgo.utils.fmodule.
        src_model = copy.deepcopy(model)
        src_model.freeze_grad()
        model.train()
        optimizer = self.calculator.get_optimizer(
            model,
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
            momentum=self.momentum,
        )
        # ``_`` instead of ``iter`` so the builtin ``iter`` is not shadowed.
        for _ in range(self.num_steps):
            batch_data = self.get_batch_data()
            model.zero_grad()
            # Task loss from the task-specific calculator.
            loss = self.calculator.compute_loss(model, batch_data)['loss']
            # FedProx proximal penalty: (mu / 2) * squared L2 drift.
            loss = loss + 0.5 * self.mu * self._proximal_term(model, src_model)
            loss.backward()
            optimizer.step()
        return

    @staticmethod
    def _proximal_term(model, src_model):
        """Return the summed squared difference between paired parameters.

        Parameters are paired positionally via ``zip`` over ``parameters()``.
        Returns int ``0`` when the models have no parameters.
        """
        return sum(
            torch.sum(torch.pow(pm - ps, 2))
            for pm, ps in zip(model.parameters(), src_model.parameters())
        )