-
Notifications
You must be signed in to change notification settings - Fork 0
/
runner.py
executable file
·49 lines (42 loc) · 1.86 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#!/usr/bin/env python
import argparse
from client import FLClient
from server import FLServer
from cinic10_ds import get_train_ds, get_test_val_ds
import os
import yaml
def run():
    """Unimplemented stub: performs no work and returns None."""
    return None
def _parse_args(argv=None):
    """Build the command-line parser and parse *argv*.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    ``--server`` and ``--client`` are mutually exclusive. When neither is
    given, the script runs in client mode.
    """
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--model-config", type=str, required=False,
        help="YAML file whose 'rounds', 'epochs' and 'nr_of_split_per_round' "
             "keys take precedence over the matching flags "
             "(default: model_config.yml, used only if it exists)",
    )
    role = parser.add_mutually_exclusive_group()
    role.add_argument("--server", action="store_true")
    role.add_argument("--client", action="store_true")
    parser.add_argument("--rounds", type=int, default=16)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--total-clients", type=int, default=10)
    parser.add_argument("--client-index", type=int, default=0)
    parser.add_argument("--server-address", type=str, default='0.0.0.0')
    parser.add_argument("--data", type=str, default='cifar10')
    parser.add_argument("--is-poisoned", action="store_true")
    parser.add_argument("--is-noniid", action="store_true")
    # Hyphenated spelling matches the other flags; the original underscore
    # spelling is kept as an alias so existing launch scripts keep working.
    parser.add_argument(
        "--nr-of-split-per-round", "--nr_of_split_per_round",
        dest="nr_of_split_per_round", type=int, default=4,
    )
    return parser.parse_args(argv)


def _load_model_config(path, required=False):
    """Load the YAML mapping stored at *path*.

    Args:
        path: Path of the YAML config file.
        required: If true, a missing file is an error. Otherwise a missing
            file yields an empty config.

    Returns:
        The parsed mapping. An empty document or a missing optional file
        yields ``{}``.

    Raises:
        FileNotFoundError: *required* is true and *path* is not a file.
        ValueError: The YAML document is not a mapping.
    """
    if not os.path.isfile(path):
        if required:
            raise FileNotFoundError(f"model config file not found: {path}")
        return {}
    with open(path, encoding="utf-8") as f:
        # safe_load constructs only plain YAML types, never arbitrary objects.
        config = yaml.safe_load(f)
    if config is None:  # empty file / empty document
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"model config {path!r} must contain a mapping, "
            f"got {type(config).__name__}"
        )
    return config


def _config_value(config, key, fallback):
    """Return ``config[key]`` unless it is absent or None; otherwise *fallback*.

    Unlike ``config.get(key) or fallback``, explicit falsy values such as ``0``
    are kept.
    """
    value = config.get(key)
    return fallback if value is None else value


if __name__ == "__main__":
    args = _parse_args()
    # Only a config file the user explicitly asked for must exist; the
    # default model_config.yml is optional.
    model_config = _load_model_config(
        args.model_config or 'model_config.yml',
        required=bool(args.model_config),
    )
    # Values from the config file take precedence over the CLI flags.
    rounds = _config_value(model_config, 'rounds', args.rounds)
    epochs = _config_value(model_config, 'epochs', args.epochs)
    nr_of_split_per_round = _config_value(
        model_config, 'nr_of_split_per_round', args.nr_of_split_per_round
    )
    if args.server:
        FLServer(rounds, epochs, nr_of_split_per_round, args.data).start()
    else:
        # Client mode: the results of get_train_ds and get_test_val_ds are
        # unpacked positionally into FLClient.
        FLClient(
            *get_train_ds(args.total_clients, args.client_index, args.data),
            *get_test_val_ds(args.data),
            is_poisoned=args.is_poisoned,
            is_noniid=args.is_noniid,
            data=args.data
        ).start(args.server_address)