forked from insait-institute/dager-gradient-inversion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathargs_factory.py
99 lines (85 loc) · 4.89 KB
/
args_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import argparse
import time
import sys
def get_args(argv=None):
parser = argparse.ArgumentParser(description='DAGER attack')
parser.add_argument('--neptune', type=str, help='neptune project name, leave empty to not use neptune', default=None)
parser.add_argument('--neptune_offline', action='store_true', help='Run Neptune in offline mode')
parser.add_argument('--label', type=str, default='name of the run', required=False)
# Method and setting
parser.add_argument('--rng_seed', type=int, default=101)
parser.add_argument('--dataset', choices=['cola', 'sst2', 'rte', 'rotten_tomatoes', 'stanfordnlp/imdb', 'glnmario/ECHR'], required=True)
parser.add_argument('--task', choices=['seq_class', 'next_token_pred'], required=True)
parser.add_argument('--pad', choices=['right', 'left'], default='right')
parser.add_argument('--split', choices=['val', 'test'], required=True)
parser.add_argument('-b','--batch_size', type=int, default=1)
parser.add_argument('--n_inputs', type=int, required=True) # val:10/20, test:100
parser.add_argument('--start_input', type=int, default=0)
parser.add_argument('--end_input', type=int, default=100000)
# Model path (defaults to huggingface download, use local path if offline)
parser.add_argument('--model_path', type=str, default='bert-base-uncased')
parser.add_argument('--finetuned_path', type=str, default=None)
parser.add_argument('--cache_dir', type=str, default=None)
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--device_grad', type=str, default='cpu')
parser.add_argument('--attn_implementation', type=str, default='sdpa', choices=['sdpa', 'eager'])
parser.add_argument('--precision', type=str, default='full', choices=['8bit', 'half', 'full', 'double'])
parser.add_argument('--parallel', type=int, default=1000)
parser.add_argument('--grad_b', type=int, default=None)
parser.add_argument('--n_layers', type=int, default=2)
parser.add_argument('--rank_tol', type=float, default=None)
parser.add_argument('--rank_cutoff', type=int, default=20)
parser.add_argument('--l1_span_thresh', type=float, default=1e-5)
parser.add_argument('--l2_span_thresh', type=float, default=1e-3)
parser.add_argument('--l1_filter', choices=['maxB', 'all'], required=True)
parser.add_argument('--l2_filter', choices=['overlap', 'non-overlap'], required=True)
parser.add_argument('--distinct_thresh', type=float, default=0.7)
parser.add_argument('--max_ids', type=int, default=-1)
parser.add_argument('--maxC', type=int, default=10000000)
parser.add_argument('--reduce_incorrect', type=int, default=0)
parser.add_argument('--n_incorrect', type=int, default=None)
# FedAVG
parser.add_argument('--algo', type=str, default='sgd', choices=['sgd', 'fedavg'])
parser.add_argument('--avg_epochs', type=int, default=None)
parser.add_argument('--b_mini', type=int, default=None)
parser.add_argument('--avg_lr', type=float, default=None)
parser.add_argument('--dist_norm', type=str, default='l2', choices=['l1', 'l2'])
#DP
parser.add_argument('--defense_noise', type=float, default=None) # add noise to true grads
parser.add_argument('--max_len', type=int, default=1e10)
parser.add_argument('--p1_std_thrs', type=float, default=5)
parser.add_argument('--l2_std_thrs', type=float, default=5)
parser.add_argument('--dp_l2_filter', type=str, default='maxB', choices=['maxB', 'outliers'])
parser.add_argument('--defense_pct_mask', type=float, default=None) # mask some percentage of gradients
#Dropout
parser.add_argument('--grad_mode', type=str, default='eval', choices=['eval', 'train'])
#Rebuttal experiments
parser.add_argument('--hidden_act', type=str, default=None)
parser.add_argument('--loss', type=str, default='ce', choices=['ce', 'mse'])
#LoRA
parser.add_argument('--train_method', type=str, default='full', choices=['full', 'lora'])
parser.add_argument('--lora_r', type=int, default=None)
if argv is None:
argv = sys.argv[1:]
args=parser.parse_args(argv)
if args.n_incorrect is None:
args.n_incorrect = args.batch_size
if args.neptune is not None:
import neptune.new as neptune
assert('label' in args)
nep_par = { 'project':f"{args.neptune}", 'source_files':["*.py"] }
if args.neptune_offline:
nep_par['mode'] = 'offline'
args.neptune_id = 'DAG-0'
run = neptune.init( **nep_par )
args_dict = vars(args)
run[f"parameters"] = args_dict
args.neptune = run
if not args.neptune_offline:
print('waiting...')
start_wait=time.time()
args.neptune.wait()
print('waited: ',time.time()-start_wait)
args.neptune_id = args.neptune['sys/id'].fetch()
print( '\n\n\nArgs:', *argv, '\n\n\n' )
return args