From 0797a50943f96f074687138dc39d5b7b9cc2f334 Mon Sep 17 00:00:00 2001 From: Michael112233 <88572680+Michael112233@users.noreply.github.com> Date: Thu, 18 Jul 2024 22:16:35 +0800 Subject: [PATCH] init --- main_fed.py | 6 +++++- models/Fed.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/main_fed.py b/main_fed.py index d170fc8..e54ec06 100644 --- a/main_fed.py +++ b/main_fed.py @@ -71,6 +71,10 @@ best_loss = None val_acc_list, net_list = [], [] + dict_len = [] + for i in range(args.num_users): + dict_len.append(len(dict_users[i])) + if args.all_clients: print("Aggregation over all clients") w_locals = [w_glob for i in range(args.num_users)] @@ -89,7 +93,7 @@ w_locals.append(copy.deepcopy(w)) loss_locals.append(copy.deepcopy(loss)) # update global weights - w_glob = FedAvg(w_locals) + w_glob = FedAvg(w_locals, dict_len) # copy weight to net_glob net_glob.load_state_dict(w_glob) diff --git a/models/Fed.py b/models/Fed.py index 29a03d1..635c9c3 100644 --- a/models/Fed.py +++ b/models/Fed.py @@ -7,10 +7,11 @@ from torch import nn -def FedAvg(w): +def FedAvg(w, dict_len): w_avg = copy.deepcopy(w[0]) for k in w_avg.keys(): + w_avg[k] = w_avg[k] * dict_len[0] for i in range(1, len(w)): - w_avg[k] += w[i][k] + w_avg[k] += w[i][k] * dict_len[i] w_avg[k] = torch.div(w_avg[k], len(w)) return w_avg